Skip to content
Snippets Groups Projects
Commit 614fe2ea authored by Pratik Bhatu's avatar Pratik Bhatu
Browse files

Move loop versions of convolution to cpp library

parent 29656c7a
No related branches found
No related tags found
No related merge requests found
......@@ -455,26 +455,6 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
Conv2DReshapeMatMulOP(N, newH, newW, CO, matmulOP, outArr);
}
(* int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI][CO] filterArr,
int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Wrapper around Conv2DLoop: derives the two output spatial dimensions from
   H/W, FH/FW, the left/right pad amounts and the strides, then forwards every
   argument unchanged together with the derived dimensions. *)
def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideH, int32_pl strideW, int32_pl G,
                       int32_al[N][H][W][CI] inputArr,
                       int32_al[FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv2DLoop(N, H, W, CI,
               FH, FW, CO,
               zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideH, strideW,
               heightOut, widthOut,
               G,
               inputArr, filterArr, consSF, outArr);
}
(**************************)
(* Generic implementation of Conv2D with Groups *)
......@@ -669,27 +649,6 @@ def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int3
};
}
(* int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Loop implementation of convolution; runs faster with multithreading *)
(* Wrapper around Conv3DLoop: derives the three output spatial dimensions from
   D/H/W, FD/FH/FW, the left/right pad amounts and the strides, then forwards
   every argument unchanged together with the derived dimensions. *)
def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideD, int32_pl strideH, int32_pl strideW,
                       int32_al[N][D][H][W][CI] inputArr,
                       int32_al[FD][FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padD = zPadDLeft + zPadDRight;
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl depthOut = ((D - FD + padD) / strideD) + 1;
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv3DLoop(N, D, H, W, CI,
               FD, FH, FW, CO,
               zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideD, strideH, strideW,
               depthOut, heightOut, widthOut,
               inputArr, filterArr, consSF, outArr);
}
(* int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
......@@ -1016,23 +975,6 @@ def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrim
};
}
(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int32_al[FD][FH][FW][CO][CI] filterArr,
int32_al[N][D][H][W][CO] outArr
*)
(* Wrapper around ConvTranspose3DLoop. No values are computed here: the
   arguments are only reordered so that D, H, W (the outArr dimensions) are
   passed after the strides instead of after CO. *)
def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
                                int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                                int32_pl D, int32_pl H, int32_pl W,
                                int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
                                int32_pl strideD, int32_pl strideH, int32_pl strideW,
                                int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
                                int32_al[FD][FH][FW][CO][CI] filterArr,
                                int32_pl consSF,
                                int32_al[N][D][H][W][CO] outArr)
{
    ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI,
                        FD, FH, FW, CO,
                        zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight,
                        strideD, strideH, strideW,
                        D, H, W,
                        inputArr, filterArr, consSF, outArr);
}
(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int32_al[FD][FH][FW][CO][CI] filter,
int32_al[N][D][H][W][CO] outputArr
......@@ -1081,4 +1023,4 @@ def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int
(* No-op in this library: the body returns immediately and never reads or
   writes arr or any of the five dimension arguments. *)
def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr)
{
return;
}
\ No newline at end of file
}
......@@ -84,6 +84,26 @@ def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
};
}
(* int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI][CO] filterArr,
int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Wrapper around Conv2DLoop: derives the two output spatial dimensions from
   H/W, FH/FW, the left/right pad amounts and the strides, then forwards every
   argument unchanged together with the derived dimensions. *)
def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideH, int32_pl strideW, int32_pl G,
                       int32_al[N][H][W][CI] inputArr,
                       int32_al[FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv2DLoop(N, H, W, CI,
               FH, FW, CO,
               zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideH, strideW,
               heightOut, widthOut,
               G,
               inputArr, filterArr, consSF, outArr);
}
(**************************)
def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
......@@ -126,6 +146,26 @@ def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
};
}
(* int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Loop implementation of convolution; runs faster with multithreading *)
(* Wrapper around Conv3DLoop: derives the three output spatial dimensions from
   D/H/W, FD/FH/FW, the left/right pad amounts and the strides, then forwards
   every argument unchanged together with the derived dimensions. *)
def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideD, int32_pl strideH, int32_pl strideW,
                       int32_al[N][D][H][W][CI] inputArr,
                       int32_al[FD][FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padD = zPadDLeft + zPadDRight;
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl depthOut = ((D - FD + padD) / strideD) + 1;
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv3DLoop(N, D, H, W, CI,
               FD, FH, FW, CO,
               zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideD, strideH, strideW,
               depthOut, heightOut, widthOut,
               inputArr, filterArr, consSF, outArr);
}
(**************************)
def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
......@@ -171,7 +211,22 @@ def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int
};
};
}
(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int32_al[FD][FH][FW][CO][CI] filterArr,
int32_al[N][D][H][W][CO] outArr
*)
(* Wrapper around ConvTranspose3DLoop. No values are computed here: the
   arguments are only reordered so that D, H, W (the outArr dimensions) are
   passed after the strides instead of after CO. *)
def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
                                int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                                int32_pl D, int32_pl H, int32_pl W,
                                int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
                                int32_pl strideD, int32_pl strideH, int32_pl strideW,
                                int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
                                int32_al[FD][FH][FW][CO][CI] filterArr,
                                int32_pl consSF,
                                int32_al[N][D][H][W][CO] outArr)
{
    ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI,
                        FD, FH, FW, CO,
                        zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight,
                        strideD, strideH, strideW,
                        D, H, W,
                        inputArr, filterArr, consSF, outArr);
}
(**************************)
def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int32_al[inArrS1][inArrS2] inArr, int32_pl dim, int32_al[outArrS1] outArr){
......@@ -594,4 +649,4 @@ def void StartComputation()
(* No-op in this library: takes no arguments and returns immediately. *)
def void EndComputation()
{
return;
}
\ No newline at end of file
}
......@@ -460,20 +460,6 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Wrapper around Conv2DLoop (int64_al arrays): derives the two output spatial
   dimensions from H/W, FH/FW, the left/right pad amounts and the strides, then
   forwards every argument unchanged together with the derived dimensions. *)
def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideH, int32_pl strideW, int32_pl G,
                       int64_al[N][H][W][CI] inputArr,
                       int64_al[FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv2DLoop(N, H, W, CI,
               FH, FW, CO,
               zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideH, strideW,
               heightOut, widthOut,
               G,
               inputArr, filterArr, consSF, outArr);
}
(**************************)
(* Generic implementation of Conv2D with Groups *)
......@@ -669,26 +655,6 @@ def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int3
};
}
(* int64_al[N][D][H][W][CI] inputArr,
int64_al[FD][FH][FW][CI][CO] filterArr,
int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Loop implementation of convolution; runs faster with multithreading *)
(* Wrapper around Conv3DLoop (int64_al arrays): derives the three output
   spatial dimensions from D/H/W, FD/FH/FW, the left/right pad amounts and the
   strides, then forwards every argument unchanged together with the derived
   dimensions. *)
def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideD, int32_pl strideH, int32_pl strideW,
                       int64_al[N][D][H][W][CI] inputArr,
                       int64_al[FD][FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padD = zPadDLeft + zPadDRight;
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl depthOut = ((D - FD + padD) / strideD) + 1;
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv3DLoop(N, D, H, W, CI,
               FD, FH, FW, CO,
               zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideD, strideH, strideW,
               depthOut, heightOut, widthOut,
               inputArr, filterArr, consSF, outArr);
}
(* int64_al[N][D][H][W][CI] inputArr,
int64_al[FD][FH][FW][CI][CO] filterArr,
......@@ -1015,24 +981,6 @@ def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrim
};
};
}
(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int64_al[FD][FH][FW][CO][CI] filterArr,
int64_al[N][D][H][W][CO] outArr
*)
(* Wrapper around ConvTranspose3DLoop (int64_al arrays). No values are computed
   here: the arguments are only reordered so that D, H, W (the outArr
   dimensions) are passed after the strides instead of after CO. *)
def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
                                int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                                int32_pl D, int32_pl H, int32_pl W,
                                int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
                                int32_pl strideD, int32_pl strideH, int32_pl strideW,
                                int64_al[N][DPrime][HPrime][WPrime][CI] inputArr,
                                int64_al[FD][FH][FW][CO][CI] filterArr,
                                int32_pl consSF,
                                int64_al[N][D][H][W][CO] outArr)
{
    ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI,
                        FD, FH, FW, CO,
                        zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight,
                        strideD, strideH, strideW,
                        D, H, W,
                        inputArr, filterArr, consSF, outArr);
}
(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int64_al[FD][FH][FW][CO][CI] filter,
int64_al[N][D][H][W][CO] outputArr
......@@ -1081,4 +1029,4 @@ def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int
(* No-op in this library: the body returns immediately and never reads or
   writes arr or any of the five dimension arguments. *)
def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr)
{
return;
}
\ No newline at end of file
}
......@@ -83,6 +83,25 @@ def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
};
};
}
(* int64_al[N][H][W][CI] inputArr,
int64_al[FH][FW][CI][CO] filterArr,
int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Wrapper around Conv2DLoop (int64_al arrays): derives the two output spatial
   dimensions from H/W, FH/FW, the left/right pad amounts and the strides, then
   forwards every argument unchanged together with the derived dimensions. *)
def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideH, int32_pl strideW, int32_pl G,
                       int64_al[N][H][W][CI] inputArr,
                       int64_al[FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv2DLoop(N, H, W, CI,
               FH, FW, CO,
               zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideH, strideW,
               heightOut, widthOut,
               G,
               inputArr, filterArr, consSF, outArr);
}
(**************************)
def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
......@@ -126,6 +145,26 @@ def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
};
}
(* int64_al[N][D][H][W][CI] inputArr,
int64_al[FD][FH][FW][CI][CO] filterArr,
int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
(* Loop implementation of convolution; runs faster with multithreading *)
(* Wrapper around Conv3DLoop (int64_al arrays): derives the three output
   spatial dimensions from D/H/W, FD/FH/FW, the left/right pad amounts and the
   strides, then forwards every argument unchanged together with the derived
   dimensions. *)
def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
                       int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                       int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
                       int32_pl strideD, int32_pl strideH, int32_pl strideW,
                       int64_al[N][D][H][W][CI] inputArr,
                       int64_al[FD][FH][FW][CI][CO] filterArr,
                       int32_pl consSF,
                       int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Combined left + right padding along each spatial axis. *)
    int32_pl padD = zPadDLeft + zPadDRight;
    int32_pl padH = zPadHLeft + zPadHRight;
    int32_pl padW = zPadWLeft + zPadWRight;

    (* Same formula as the outArr dimensions declared in the signature. *)
    int32_pl depthOut = ((D - FD + padD) / strideD) + 1;
    int32_pl heightOut = ((H - FH + padH) / strideH) + 1;
    int32_pl widthOut = ((W - FW + padW) / strideW) + 1;

    Conv3DLoop(N, D, H, W, CI,
               FD, FH, FW, CO,
               zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight,
               strideD, strideH, strideW,
               depthOut, heightOut, widthOut,
               inputArr, filterArr, consSF, outArr);
}
(**************************)
def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
......@@ -171,7 +210,22 @@ def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int
};
};
}
(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int64_al[FD][FH][FW][CO][CI] filterArr,
int64_al[N][D][H][W][CO] outArr
*)
(* Wrapper around ConvTranspose3DLoop (int64_al arrays). No values are computed
   here: the arguments are only reordered so that D, H, W (the outArr
   dimensions) are passed after the strides instead of after CO. *)
def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
                                int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
                                int32_pl D, int32_pl H, int32_pl W,
                                int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
                                int32_pl strideD, int32_pl strideH, int32_pl strideW,
                                int64_al[N][DPrime][HPrime][WPrime][CI] inputArr,
                                int64_al[FD][FH][FW][CO][CI] filterArr,
                                int32_pl consSF,
                                int64_al[N][D][H][W][CO] outArr)
{
    ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI,
                        FD, FH, FW, CO,
                        zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight,
                        strideD, strideH, strideW,
                        D, H, W,
                        inputArr, filterArr, consSF, outArr);
}
(**************************)
def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int64_al[inArrS1][inArrS2] inArr, int32_pl dim, int64_al[outArrS1] outArr){
......@@ -594,4 +648,4 @@ def void StartComputation()
(* No-op in this library: takes no arguments and returns immediately. *)
def void EndComputation()
{
return;
}
\ No newline at end of file
}
......@@ -76,7 +76,7 @@ extern void ClearMemSecret1(int32_pl s1, int64_al[s1] arr);
extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr);
extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int64_al[s1][s2][s3] arr);
extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr);
extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr)
extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr);
extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment