(*
Library32_common.ezpc (64.70 KiB)
Last change, authored by Bhatu:
-- Implicit broadcasting: if any of the dims in the input is 1, the values along that dimension need to be broadcast so as to match the output dimension. We add ternary operators that check, at each iteration of that dim, whether the inputs have 1 as the dim, and choose array indexes appropriately.
-- Add support for secure conv3d using multithreaded conv instead of matmul.
-- Fix bug related to overflow in send/receive message.
-- ONNXCompiler: generate input and model parameter inputs separately for 3pc computation too.
*)
(*
Authors: Nishant Kumar.
Copyright:
Copyright (c) 2018 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*)
(**************************)
(* MatAddBroadCast2: adds the 1-D array B to every row of the 2-D array A.
   For each output position, outArr[row][col] = A[row][col] + B[col].
   The loops run over the output extents s1 x s2; a1, a2 and b1 appear only
   in the declared types of A and B and are not read in the body. *)
def void MatAddBroadCast2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl s1, int32_pl s2, int32_al[a1][a2] A, int32_al[b1] B, int32_al[s1][s2] outArr){
    for row=[0:s1]{
        for col=[0:s2]{
            outArr[row][col] = A[row][col] + B[col];
        };
    };
}
(* MatAdd2: element-wise sum of two 2-D arrays with implicit broadcasting.
   On each axis, an input whose extent is 1 is always read at index 0;
   otherwise it is read at the current output index. The sum is written to
   every position of the s1 x s2 output. *)
def void MatAdd2(int32_pl a1, int32_pl a2, int32_pl b1, int32_pl b2, int32_pl s1, int32_pl s2, int32_al[a1][a2] A, int32_al[b1][b2] B, int32_al[s1][s2] outArr){
    for r=[0:s1]{
        (* Row to read from each input for output row r. *)
        int32_pl rowA = ((a1 == 1) ? 0 : r);
        int32_pl rowB = ((b1 == 1) ? 0 : r);
        for c=[0:s2]{
            (* Column to read from each input for output column c. *)
            int32_pl colA = ((a2 == 1) ? 0 : c);
            int32_pl colB = ((b2 == 1) ? 0 : c);
            outArr[r][c] = A[rowA][colA] + B[rowB][colB];
        };
    };
}
(* MatAddBroadCast4: adds the 1-D array B along the last axis of the 4-D
   array A. For each output position,
   outArr[i1][i2][i3][i4] = A[i1][i2][i3][i4] + B[i4].
   A, B and outArr are all typed with the output extents s1..s4; the
   parameters a1..a4 and b1 are accepted but not read in the body. *)
def void MatAddBroadCast4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] A, int32_al[s4] B, int32_al[s1][s2][s3][s4] outArr){
    for i1=[0:s1]{
        for i2=[0:s2]{
            for i3=[0:s3]{
                for i4=[0:s4]{
                    outArr[i1][i2][i3][i4] = A[i1][i2][i3][i4] + B[i4];
                };
            };
        };
    };
}
(* MatAddBroadCast5: adds the 1-D array B along the last axis of the 5-D
   array A. For each output position,
   outArr[p][q][r][s][t] = A[p][q][r][s][t] + B[t].
   A, B and outArr are typed with the output extents s1..s5; the parameters
   a1..a5 and b1 are accepted but not read in the body. *)
def void MatAddBroadCast5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] A, int32_al[s5] B, int32_al[s1][s2][s3][s4][s5] outArr){
    for p=[0:s1]{
        for q=[0:s2]{
            for r=[0:s3]{
                for s=[0:s4]{
                    for t=[0:s5]{
                        outArr[p][q][r][s][t] = A[p][q][r][s][t] + B[t];
                    };
                };
            };
        };
    };
}
(* MatAdd4: element-wise sum of two 4-D arrays with implicit broadcasting.
   On each axis, an input whose extent is 1 is always read at index 0;
   otherwise it is read at the current output index. The sum is written to
   every position of the s1 x s2 x s3 x s4 output. *)
def void MatAdd4(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[a1][a2][a3][a4] A, int32_al[b1][b2][b3][b4] B, int32_al[s1][s2][s3][s4] outArr){
    for o1=[0:s1]{
        int32_pl ia1 = ((a1 == 1) ? 0 : o1);
        int32_pl ib1 = ((b1 == 1) ? 0 : o1);
        for o2=[0:s2]{
            int32_pl ia2 = ((a2 == 1) ? 0 : o2);
            int32_pl ib2 = ((b2 == 1) ? 0 : o2);
            for o3=[0:s3]{
                int32_pl ia3 = ((a3 == 1) ? 0 : o3);
                int32_pl ib3 = ((b3 == 1) ? 0 : o3);
                for o4=[0:s4]{
                    int32_pl ia4 = ((a4 == 1) ? 0 : o4);
                    int32_pl ib4 = ((b4 == 1) ? 0 : o4);
                    outArr[o1][o2][o3][o4] = A[ia1][ia2][ia3][ia4] + B[ib1][ib2][ib3][ib4];
                };
            };
        };
    };
}
(* MatAdd5: element-wise sum of two 5-D arrays with implicit broadcasting.
   On each axis, an input whose extent is 1 is always read at index 0;
   otherwise it is read at the current output index. The sum is written to
   every position of the s1 x s2 x s3 x s4 x s5 output. *)
def void MatAdd5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5, int32_pl b1, int32_pl b2, int32_pl b3, int32_pl b4, int32_pl b5, int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[a1][a2][a3][a4][a5] A, int32_al[b1][b2][b3][b4][b5] B, int32_al[s1][s2][s3][s4][s5] outArr){
    for o1=[0:s1]{
        int32_pl ia1 = ((a1 == 1) ? 0 : o1);
        int32_pl ib1 = ((b1 == 1) ? 0 : o1);
        for o2=[0:s2]{
            int32_pl ia2 = ((a2 == 1) ? 0 : o2);
            int32_pl ib2 = ((b2 == 1) ? 0 : o2);
            for o3=[0:s3]{
                int32_pl ia3 = ((a3 == 1) ? 0 : o3);
                int32_pl ib3 = ((b3 == 1) ? 0 : o3);
                for o4=[0:s4]{
                    int32_pl ia4 = ((a4 == 1) ? 0 : o4);
                    int32_pl ib4 = ((b4 == 1) ? 0 : o4);
                    for o5=[0:s5]{
                        int32_pl ia5 = ((a5 == 1) ? 0 : o5);
                        int32_pl ib5 = ((b5 == 1) ? 0 : o5);
                        outArr[o1][o2][o3][o4][o5] = A[ia1][ia2][ia3][ia4][ia5] + B[ib1][ib2][ib3][ib4][ib5];
                    };
                };
            };
        };
    };
}
(**************************)
(* CreateTensor1: fills every element of the public 1-D array arr
   (length s1) with the public value val. *)
def void CreateTensor1(int32_pl s1, int32_pl val, int32_pl[s1] arr){
    for idx=[0:s1]{
        arr[idx] = val;
    };
}
(* CreateTensor2: fills every element of the public s1 x s2 array arr
   with the public value val. *)
def void CreateTensor2(int32_pl s1, int32_pl s2, int32_pl val, int32_pl[s1][s2] arr){
    for r=[0:s1]{
        for c=[0:s2]{
            arr[r][c] = val;
        };
    };
}
(* CreateTensor3: fills every element of the public s1 x s2 x s3 array arr
   with the public value val. *)
def void CreateTensor3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl val, int32_pl[s1][s2][s3] arr){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                arr[a][b][c] = val;
            };
        };
    };
}
(* CreateTensor4: fills every element of the public s1 x s2 x s3 x s4
   array arr with the public value val. *)
def void CreateTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl val, int32_pl[s1][s2][s3][s4] arr){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                for d=[0:s4]{
                    arr[a][b][c][d] = val;
                };
            };
        };
    };
}
(* CreateTensor5: fills every element of the public
   s1 x s2 x s3 x s4 x s5 array arr with the public value val. *)
def void CreateTensor5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl val, int32_pl[s1][s2][s3][s4][s5] arr){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                for d=[0:s4]{
                    for e=[0:s5]{
                        arr[a][b][c][d][e] = val;
                    };
                };
            };
        };
    };
}
(**************************)
(* CopyTensor1: copies fromArr into targetArr element by element
   (length s1). The parameter ignore is accepted but never read or written. *)
def void CopyTensor1(int32_pl s1, int32_al[s1] targetArr, int32_al[s1] fromArr, int32_al[s1] ignore){
    for idx=[0:s1]{
        targetArr[idx] = fromArr[idx];
    };
}
(* CopyTensor2: copies the s1 x s2 array fromArr into targetArr element by
   element. The parameter ignore is accepted but never read or written. *)
def void CopyTensor2(int32_pl s1, int32_pl s2, int32_al[s1][s2] targetArr, int32_al[s1][s2] fromArr, int32_al[s1][s2] ignore){
    for r=[0:s1]{
        for c=[0:s2]{
            targetArr[r][c] = fromArr[r][c];
        };
    };
}
(* CopyTensor3: copies the s1 x s2 x s3 array fromArr into targetArr element
   by element. The parameter ignore is accepted but never read or written. *)
def void CopyTensor3(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] targetArr, int32_al[s1][s2][s3] fromArr, int32_al[s1][s2][s3] ignore){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                targetArr[a][b][c] = fromArr[a][b][c];
            };
        };
    };
}
(* CopyTensor4: copies the s1 x s2 x s3 x s4 array fromArr into targetArr
   element by element. The parameter ignore is accepted but never read or
   written. *)
def void CopyTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] targetArr, int32_al[s1][s2][s3][s4] fromArr, int32_al[s1][s2][s3][s4] ignore){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                for d=[0:s4]{
                    targetArr[a][b][c][d] = fromArr[a][b][c][d];
                };
            };
        };
    };
}
(**************************)
(* CreateIdentity11: writes an element-wise copy of the 1-D array fromArr
   (length s1) into newArr. *)
def void CreateIdentity11(int32_pl s1, int32_al[s1] fromArr, int32_al[s1] newArr){
    for idx=[0:s1]{
        newArr[idx] = fromArr[idx];
    };
}
(* CreateIdentity22: writes an element-wise copy of the s1 x s2 array
   fromArr into newArr. *)
def void CreateIdentity22(int32_pl s1, int32_pl s2, int32_al[s1][s2] fromArr, int32_al[s1][s2] newArr){
    for r=[0:s1]{
        for c=[0:s2]{
            newArr[r][c] = fromArr[r][c];
        };
    };
}
(* CreateIdentity33: writes an element-wise copy of the s1 x s2 x s3 array
   fromArr into newArr. *)
def void CreateIdentity33(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] fromArr, int32_al[s1][s2][s3] newArr){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                newArr[a][b][c] = fromArr[a][b][c];
            };
        };
    };
}
(* CreateIdentity44: writes an element-wise copy of the s1 x s2 x s3 x s4
   array fromArr into newArr. *)
def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] fromArr, int32_al[s1][s2][s3][s4] newArr){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                for d=[0:s4]{
                    newArr[a][b][c][d] = fromArr[a][b][c][d];
                };
            };
        };
    };
}
(**************************)
(* CreateCopy2211: copies an s1 x s2 window of the 2-D array inArr into
   outArr. The window starts at row beginIdx[0] and column beginIdx[1] of
   inArr. sizeIdx is accepted but not read; the window extent comes from
   s1 and s2. *)
def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int32_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2] outArr){
    for r=[0:s1]{
        for c=[0:s2]{
            outArr[r][c] = inArr[beginIdx[0]+r][beginIdx[1]+c];
        };
    };
}
(* CreateCopy5511: copies an s1 x s2 x s3 x s4 x s5 window of the 5-D array
   inArr into outArr. On axis k the window starts at beginIdx[k]. sizeIdx is
   accepted but not read; the window extent comes from s1..s5. *)
def void CreateCopy5511(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_pl inps5, int32_al[inps1][inps2][inps3][inps4][inps5] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2][s3][s4][s5] outArr){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                for d=[0:s4]{
                    for e=[0:s5]{
                        outArr[a][b][c][d][e] = inArr[beginIdx[0]+a][beginIdx[1]+b][beginIdx[2]+c][beginIdx[3]+d][beginIdx[4]+e];
                    };
                };
            };
        };
    };
}
(**************************)
(* Concat2T222: concatenates the 2-D arrays inp1 and inp2 into outp.
   With axis equal to 0 the arrays are joined along the first axis; any other
   axis value joins them along the second axis. Output positions that fall
   inside inp1's extent on the join axis are read from inp1; the rest are
   read from inp2 with the join-axis index shifted down by inp1's extent. *)
def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int32_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_al[inp2s1][inp2s2] inp2, int32_pl axis, int32_al[s1][s2] outp){
    for r=[0:s1]{
        for c=[0:s2]{
            if (axis==0){
                if (r >= inp1s1){
                    outp[r][c] = inp2[r-inp1s1][c];
                }
                else{
                    outp[r][c] = inp1[r][c];
                };
            }
            else{
                if (c >= inp1s2){
                    outp[r][c] = inp2[r][c-inp1s2];
                }
                else{
                    outp[r][c] = inp1[r][c];
                };
            };
        };
    };
}
(* Concat2T444: concatenates the 4-D arrays inp1 and inp2 into outp along
   the given axis. Axis values 0, 1 and 2 select the first three axes; any
   other value is treated as the last axis. Output positions that fall inside
   inp1's extent on the join axis are read from inp1; the rest are read from
   inp2 with the join-axis index shifted down by inp1's extent. *)
def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){
    for a=[0:s1]{
        for b=[0:s2]{
            for c=[0:s3]{
                for d=[0:s4]{
                    if (axis==0){
                        if (a >= inp1s1){
                            outp[a][b][c][d] = inp2[a-inp1s1][b][c][d];
                        }
                        else{
                            outp[a][b][c][d] = inp1[a][b][c][d];
                        };
                    }
                    else{
                        if (axis==1){
                            if (b >= inp1s2){
                                outp[a][b][c][d] = inp2[a][b-inp1s2][c][d];
                            }
                            else{
                                outp[a][b][c][d] = inp1[a][b][c][d];
                            };
                        }
                        else{
                            if (axis==2){
                                if (c >= inp1s3){
                                    outp[a][b][c][d] = inp2[a][b][c-inp1s3][d];
                                }
                                else{
                                    outp[a][b][c][d] = inp1[a][b][c][d];
                                };
                            }
                            else{
                                (* Remaining case: join along the last axis. *)
                                if (d >= inp1s4){
                                    outp[a][b][c][d] = inp2[a][b][c][d - inp1s4];
                                }
                                else{
                                    outp[a][b][c][d] = inp1[a][b][c][d];
                                };
                            };
                        };
                    };
                };
            };
        };
    };
}
(**************************)
(* Split44: extracts one piece of the 4-D array inp into out.
   On the split axis (0..3) the piece starts at (extent/total)*curCount, where
   extent is inp's size on that axis; every other axis starts at 0. The piece
   has the extents of out (O1 x O2 x O3 x O4). An axis value outside 0..3
   leaves all offsets at 0. *)
def void Split44(int32_pl O1, int32_pl O2, int32_pl O3, int32_pl O4, int32_pl I1, int32_pl I2, int32_pl I3, int32_pl I4, int32_al[I1][I2][I3][I4] inp, int32_pl axis, int32_pl curCount, int32_pl total, int32_al[O1][O2][O3][O4] out){
    (* Per-axis start offsets into inp; only the split axis gets a non-zero one. *)
    int32_pl start1 = 0;
    int32_pl start2 = 0;
    int32_pl start3 = 0;
    int32_pl start4 = 0;
    if(axis == 0){
        start1 = (I1/total)*curCount;
    };
    if(axis == 1){
        start2 = (I2/total)*curCount;
    };
    if(axis == 2){
        start3 = (I3/total)*curCount;
    };
    if(axis == 3){
        start4 = (I4/total)*curCount;
    };
    for a=[0:O1]{
        for b=[0:O2]{
            for c=[0:O3]{
                for d=[0:O4]{
                    out[a][b][c][d] = inp[start1+a][start2+b][start3+c][start4+d];
                };
            };
        };
    };
}
(**************************)
(* Generic implementation of Conv2D *)
(* Conv2DReshapeFilter: flattens a [FH][FW][CI][CO] filter into a
   [CO][FH*FW*CI] matrix. Row co holds the filter for output channel co, with
   the column index linearised as (fh*FW*CI) + (fw*CI) + ci. *)
def void Conv2DReshapeFilter(int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_al[FH][FW][CI][CO] inputArr, int32_al[CO][FH*FW*CI] outputArr){
    for oc=[0:CO]{
        for kh=[0:FH]{
            for kw=[0:FW]{
                (* Column of the first input channel for this filter tap. *)
                int32_pl tapBase = (kh*FW*CI) + (kw*CI);
                for ic=[0:CI]{
                    outputArr[oc][tapBase + ic] = inputArr[kh][kw][ic][oc];
                };
            };
        };
    };
}
(* Conv2DReshapeMatMulOP: unpacks a [CO][N*finalH*finalW] matrix into an
   [N][finalH][finalW][CO] array. Column (n*finalH*finalW) + (h*finalW) + w of
   row co is written to outputArr[n][h][w][co]. *)
def void Conv2DReshapeMatMulOP(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_al[CO][N*finalH*finalW] inputArr, int32_al[N][finalH][finalW][CO] outputArr){
    for oc=[0:CO]{
        for b=[0:N]{
            for r=[0:finalH]{
                (* Column of the first element of output row r in image b. *)
                int32_pl rowBase = (b*finalH*finalW) + (r*finalW);
                for c=[0:finalW]{
                    outputArr[b][r][c][oc] = inputArr[oc][rowBase + c];
                };
            };
        };
    };
}
(* Conv2DReshapeInput: rearranges the [N][H][W][CI] input into a matrix in
   which every column is one filter-sized window.
   For each image n, a FH x FW window slides over the zero-padded input with
   steps strideH and strideW. The window's values are written down one column
   of outputArr at row (fh*FW*CI) + (fw*CI) + ci. Positions that fall in the
   padding are written as 0. Columns are numbered consecutively across all
   images and window positions. *)
def void Conv2DReshapeInput(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][H][W][CI] inputArr, int32_al[RRows][RCols] outputArr){
    (* Index of the output column being filled. *)
    int32_pl patchCol = 0;
    for n=[0:N]{
        (* Last padded row index, and the top row of the current window. *)
        int32_pl lastPadH = H - 1 + zPadHRight;
        int32_pl winTop = 0 - zPadHLeft;
        while((winTop + FH - 1) <= lastPadH){
            (* Last padded column index, and the left column of the window. *)
            int32_pl lastPadW = W - 1 + zPadWRight;
            int32_pl winLeft = 0 - zPadWLeft;
            while((winLeft + FW - 1) <= lastPadW){
                for fh=[0:FH]{
                    for fw=[0:FW]{
                        int32_pl srcH = winTop + fh;
                        int32_pl srcW = winLeft + fw;
                        int32_pl tapBase = (fh*FW*CI) + (fw*CI);
                        int32_al val = 0;
                        for ci=[0:CI]{
                            if ((srcH >= 0) && (srcH < H) && (srcW >= 0) && (srcW < W)){
                                val = inputArr[n][srcH][srcW][ci];
                            }
                            else{
                                (* Position lies in the zero padding. *)
                                val = 0;
                            };
                            outputArr[tapBase + ci][patchCol] = val;
                        };
                    };
                };
                patchCol = patchCol + 1;
                winLeft = winLeft + strideW;
            };
            winTop = winTop + strideH;
        };
    };
}
(* int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI][CO] filterArr,
int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
def void Conv2D(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideH, int32_pl strideW,
int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI][CO] filterArr,
int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* 2-D convolution expressed as one matrix multiplication:
       1. flatten the filter to a CO x (FH*FW*CI) matrix,
       2. lay out every input window as a column of a (FH*FW*CI) x (N*outRows*outCols) matrix,
       3. multiply the two with MatMul2D,
       4. unpack the CO x (N*outRows*outCols) product into outArr.
       The three temporaries are released with ClearMemSecret2 at the end. *)

    (* Spatial extents of the output. *)
    int32_pl outRows = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1);
    int32_pl outCols = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1);

    (* Shapes of the flattened operands. *)
    int32_pl filterMatRows = CO;
    int32_pl windowLen = FH*FW*CI;
    int32_pl windowCount = N * outRows * outCols;

    int32_al[filterMatRows][windowLen] filterMat;
    int32_al[windowLen][windowCount] windowMat;
    int32_al[filterMatRows][windowCount] productMat;

    Conv2DReshapeFilter(FH, FW, CI, CO, filterArr, filterMat);
    Conv2DReshapeInput(N, H, W, CI, FH, FW, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, windowLen, windowCount, inputArr, windowMat);

    (* The meaning of the trailing flag is defined by MatMul2D, which is not part of this file section. *)
    MatMul2D(filterMatRows, windowLen, windowCount, filterMat, windowMat, productMat, true);
    Conv2DReshapeMatMulOP(N, outRows, outCols, CO, productMat, outArr);

    ClearMemSecret2(filterMatRows, windowLen, filterMat);
    ClearMemSecret2(windowLen, windowCount, windowMat);
    ClearMemSecret2(filterMatRows, windowCount, productMat);
}
(**************************)
(* Loop-based implementation of Conv2D *)
(* These loop implementations of convolution run faster with multithreading *)
def void Conv2DLoopInner(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideH, int32_pl strideW,
int32_pl outH, int32_pl outW, int32_pl G,
int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI/G][CO] filterArr,
int32_al[N][outH][outW][CO] outArr)
{
    (* Direct, loop-based grouped 2-D convolution.
       Channels are split into G groups of GIS input and GOS output channels.
       For group g, input channel ci = GIS*g + cig contributes to output
       channel co = GOS*g + cog. Positions that fall in the zero padding are
       skipped. Each partial sum is ADDED to the value already in outArr, so
       the result is only a plain convolution if outArr holds zeros on entry;
       that is the caller's responsibility. *)
    int32_pl GIS = CI/G;
    int32_pl GOS = CO/G;
    for n=[0:N]{
        for cog=[0:GOS]{
            for cig=[0:GIS]{
                for g=[0:G]{
                    for h=[0:outH]{
                        for w=[0:outW]{
                            int32_al val = 0;
                            int32_pl ci = GIS*g + cig;
                            int32_pl co = GOS*g + cog;
                            (* Input row under the first filter row for output row h. *)
                            int32_pl curPosH = strideH*h-zPadHLeft;
                            for fh=[0:FH]{
                                (* Input column under the first filter column for output column w. *)
                                int32_pl curPosW = strideW*w-zPadWLeft;
                                for fw=[0:FW]{
                                    if( (curPosH >= 0) && (curPosW >= 0) && (curPosH < H) && (curPosW < W)){
                                        (* The filter's third axis has extent CI/G and is indexed by the
                                           channel position inside the group (cig), not by ci/G. This
                                           matches the indexing in Conv2DReshapeFilterGroup. *)
                                        val = val +_al (inputArr[n][curPosH][curPosW][ci]*filterArr[fh][fw][cig][co]);
                                    };
                                    curPosW = curPosW + 1;
                                };
                                curPosH = curPosH + 1;
                            };
                            outArr[n][h][w][co] = outArr[n][h][w][co] + val;
                        };
                    };
                };
            };
        };
    };
}
(* int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI][CO] filterArr,
int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideH, int32_pl strideW, int32_pl G,
int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI][CO] filterArr,
int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Entry point for the loop-based 2-D convolution: computes the output
       height and width, then forwards every argument to Conv2DLoopInner.
       NOTE(review): filterArr is declared here as [FH][FW][CI][CO], while
       Conv2DLoopInner declares it as [FH][FW][CI/G][CO]; the two shapes agree
       only when G is 1 - confirm what callers pass for grouped convolutions. *)
    int32_pl outputHeight = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1;
    int32_pl outputWidth = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1;
    Conv2DLoopInner(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outputHeight, outputWidth, G, inputArr, filterArr, outArr);
}
(**************************)
(* Generic implementation of Conv2D with Groups *)
(* Conv2DReshapeFilterGroup: flattens the part of a grouped filter that
   belongs to group g into a [CO/G][FH*FW*(CI/G)] matrix. Row co holds the
   filter for output channel co + g*(CO/G), with the column index linearised
   as (fh*FW*CIG) + (fw*CIG) + ci, where CIG = CI/G. *)
def void Conv2DReshapeFilterGroup(int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_pl g, int32_pl G, int32_al[FH][FW][CI/G][CO] inputArr, int32_al[CO/G][FH*FW*(CI/G)] outputArr){
    int32_pl inPerGroup = CI/G;
    int32_pl outPerGroup = CO/G;
    (* First output channel that belongs to group g. *)
    int32_pl firstOut = g*outPerGroup;
    for oc=[0:outPerGroup]{
        for kh=[0:FH]{
            for kw=[0:FW]{
                int32_pl tapBase = (kh*FW*inPerGroup) + (kw*inPerGroup);
                for ic=[0:inPerGroup]{
                    outputArr[oc][tapBase + ic] = inputArr[kh][kw][ic][oc+firstOut];
                };
            };
        };
    };
}
(* Conv2DReshapeMatMulOPGroup: unpacks the [CO/G][N*finalH*finalW] product
   for group g into the output channels of that group. Column
   (n*finalH*finalW) + (h*finalW) + w of row co is written to
   outputArr[n][h][w][co + g*(CO/G)]. Channels of other groups are untouched. *)
def void Conv2DReshapeMatMulOPGroup(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_pl g, int32_pl G, int32_al[CO/G][N*finalH*finalW] inputArr, int32_al[N][finalH][finalW][CO] outputArr){
    int32_pl outPerGroup = CO/G;
    (* First output channel that belongs to group g. *)
    int32_pl firstOut = g*outPerGroup;
    for oc=[0:outPerGroup]{
        for b=[0:N]{
            for r=[0:finalH]{
                int32_pl rowBase = (b*finalH*finalW) + (r*finalW);
                for c=[0:finalW]{
                    outputArr[b][r][c][oc+firstOut] = inputArr[oc][rowBase + c];
                };
            };
        };
    };
}
(* Conv2DReshapeInputGroup: like Conv2DReshapeInput, but restricted to the
   CI/G input channels of group g.
   For each image n, a FH x FW window slides over the zero-padded input with
   steps strideH and strideW. The window's values for channels
   g*(CI/G) .. g*(CI/G) + (CI/G) - 1 are written down one column of outputArr
   at row (fh*FW*CIG) + (fw*CIG) + ci. Positions in the padding are written
   as 0. Columns are numbered consecutively across images and windows. *)
def void Conv2DReshapeInputGroup(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideH, int32_pl strideW, int32_pl g, int32_pl G, int32_pl RRows, int32_pl RCols, int32_al[N][H][W][CI] inputArr, int32_al[RRows][RCols] outputArr){
    int32_pl inPerGroup = CI/G;
    (* Index of the output column being filled. *)
    int32_pl patchCol = 0;
    for n=[0:N]{
        int32_pl lastPadH = H - 1 + zPadHRight;
        int32_pl winTop = 0 - zPadHLeft;
        while((winTop + FH - 1) <= lastPadH){
            int32_pl lastPadW = W - 1 + zPadWRight;
            int32_pl winLeft = 0 - zPadWLeft;
            while((winLeft + FW - 1) <= lastPadW){
                for fh=[0:FH]{
                    for fw=[0:FW]{
                        int32_pl srcH = winTop + fh;
                        int32_pl srcW = winLeft + fw;
                        int32_pl tapBase = (fh*FW*inPerGroup) + (fw*inPerGroup);
                        (* First input channel that belongs to group g. *)
                        int32_pl firstIn = g*inPerGroup;
                        int32_al val = 0;
                        for ci=[0:inPerGroup]{
                            if ((srcH >= 0) && (srcH < H) && (srcW >= 0) && (srcW < W)){
                                val = inputArr[n][srcH][srcW][ci+firstIn];
                            }
                            else{
                                (* Position lies in the zero padding. *)
                                val = 0;
                            };
                            outputArr[tapBase + ci][patchCol] = val;
                        };
                    };
                };
                patchCol = patchCol + 1;
                winLeft = winLeft + strideW;
            };
            winTop = winTop + strideH;
        };
    };
}
(* int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI/G][CO] filterArr,
int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
def void Conv2DGroup(int32_pl N, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideH, int32_pl strideW, int32_pl G,
int32_al[N][H][W][CI] inputArr,
int32_al[FH][FW][CI/G][CO] filterArr,
int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Grouped 2-D convolution done as one matrix multiplication per group.
       For every group g the filter slice and the input windows of that group
       are flattened, multiplied with MatMul2D, and the product is unpacked
       into the output channels of group g. The temporaries are allocated and
       released inside each iteration. *)
    int32_pl inPerGroup = CI/G;
    int32_pl outPerGroup = CO/G;
    int32_pl windowLen = FH*FW*inPerGroup;

    (* Spatial extents of the output. *)
    int32_pl outRows = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1);
    int32_pl outCols = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1);
    int32_pl windowCount = N * outRows * outCols;

    for g=[0:G]{
        int32_al[windowLen][windowCount] windowMat;
        int32_al[outPerGroup][windowCount] productMat;
        int32_al[outPerGroup][windowLen] filterMat;

        Conv2DReshapeFilterGroup(FH, FW, CI, CO, g, G, filterArr, filterMat);
        Conv2DReshapeInputGroup(N, H, W, CI, FH, FW, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, g, G, windowLen, windowCount, inputArr, windowMat);

        (* The meaning of the trailing flag is defined by MatMul2D, which is not part of this file section. *)
        MatMul2D(outPerGroup, windowLen, windowCount, filterMat, windowMat, productMat, true);
        Conv2DReshapeMatMulOPGroup(N, outRows, outCols, CO, g, G, productMat, outArr);

        ClearMemSecret2(windowLen, windowCount, windowMat);
        ClearMemSecret2(outPerGroup, windowCount, productMat);
        ClearMemSecret2(outPerGroup, windowLen, filterMat);
    }
}
(**************************)
(* Generic implementation of Conv3D *)
(* Conv3DReshapeFilter: flattens a [FD][FH][FW][CI][CO] filter into a
   [CO][FD*FH*FW*CI] matrix. Row co holds the filter for output channel co,
   with the column index linearised as
   (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci. *)
def void Conv3DReshapeFilter(int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CI, int32_pl CO, int32_al[FD][FH][FW][CI][CO] inputArr, int32_al[CO][FD*FH*FW*CI] outputArr){
    for oc=[0:CO]{
        for kd=[0:FD]{
            for kh=[0:FH]{
                for kw=[0:FW]{
                    (* Column of the first input channel for this filter tap. *)
                    int32_pl tapBase = (kd*FH*FW*CI) + (kh*FW*CI) + (kw*CI);
                    for ic=[0:CI]{
                        outputArr[oc][tapBase + ic] = inputArr[kd][kh][kw][ic][oc];
                    };
                };
            };
        };
    };
}
(* Conv3DReshapeMatMulOP: unpacks a [CO][N*finalD*finalH*finalW] matrix into
   an [N][finalD][finalH][finalW][CO] array. Column
   (n*finalD*finalH*finalW) + (d*finalH*finalW) + (h*finalW) + w of row co is
   written to outputArr[n][d][h][w][co]. *)
def void Conv3DReshapeMatMulOP(int32_pl N, int32_pl finalD, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_al[CO][N*finalD*finalH*finalW] inputArr, int32_al[N][finalD][finalH][finalW][CO] outputArr){
    for oc=[0:CO]{
        for b=[0:N]{
            for z=[0:finalD]{
                for r=[0:finalH]{
                    (* Column of the first element of output row r in slice z of image b. *)
                    int32_pl rowBase = (b*finalD*finalH*finalW) + (z*finalH*finalW) + (r*finalW);
                    for c=[0:finalW]{
                        outputArr[b][z][r][c][oc] = inputArr[oc][rowBase + c];
                    };
                };
            };
        };
    };
}
(* Conv3DReshapeInput: rearranges the [N][D][H][W][CI] input into a matrix in
   which every column is one filter-sized window.
   For each image n, a FD x FH x FW window slides over the zero-padded input
   with steps strideD, strideH and strideW. The window's values are written
   down one column of outputArr at row
   (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI) + ci. Positions in the padding are
   written as 0. Columns are numbered consecutively across images and windows. *)
def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, int32_pl strideD, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][D][H][W][CI] inputArr, int32_al[RRows][RCols] outputArr){
    (* Index of the output column being filled. *)
    int32_pl patchCol = 0;
    for n=[0:N]{
        int32_pl lastPadD = D - 1 + zPadDRight;
        int32_pl winFront = 0 - zPadDLeft;
        while((winFront + FD - 1) <= lastPadD){
            int32_pl lastPadH = H - 1 + zPadHRight;
            int32_pl winTop = 0 - zPadHLeft;
            while((winTop + FH - 1) <= lastPadH){
                int32_pl lastPadW = W - 1 + zPadWRight;
                int32_pl winLeft = 0 - zPadWLeft;
                while((winLeft + FW - 1) <= lastPadW){
                    for fd=[0:FD]{
                        for fh=[0:FH]{
                            for fw=[0:FW]{
                                int32_pl srcD = winFront + fd;
                                int32_pl srcH = winTop + fh;
                                int32_pl srcW = winLeft + fw;
                                int32_pl tapBase = (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI);
                                int32_al val = 0;
                                for ci=[0:CI]{
                                    if ((srcD >= 0) && (srcD < D) && (srcH >= 0) && (srcH < H) && (srcW >= 0) && (srcW < W)){
                                        val = inputArr[n][srcD][srcH][srcW][ci];
                                    }
                                    else{
                                        (* Position lies in the zero padding. *)
                                        val = 0;
                                    };
                                    outputArr[tapBase + ci][patchCol] = val;
                                };
                            };
                        };
                    };
                    patchCol = patchCol + 1;
                    winLeft = winLeft + strideW;
                };
                winTop = winTop + strideH;
            };
            winFront = winFront + strideD;
        };
    };
}
(* int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
def void Conv3D(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideD, int32_pl strideH, int32_pl strideW,
int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* 3-D convolution expressed as one matrix multiplication:
       1. flatten the filter to a CO x (FD*FH*FW*CI) matrix,
       2. lay out every input window as a column of a
          (FD*FH*FW*CI) x (N*outDepth*outRows*outCols) matrix,
       3. multiply the two with MatMul2D,
       4. unpack the product into outArr.
       The three temporaries are released with ClearMemSecret2 at the end. *)

    (* Spatial extents of the output. *)
    int32_pl outDepth = (((D + (zPadDLeft+zPadDRight) - FD)/strideD) + 1);
    int32_pl outRows = (((H + (zPadHLeft+zPadHRight) - FH)/strideH) + 1);
    int32_pl outCols = (((W + (zPadWLeft+zPadWRight) - FW)/strideW) + 1);

    (* Shapes of the flattened operands. *)
    int32_pl filterMatRows = CO;
    int32_pl windowLen = FD*FH*FW*CI;
    int32_pl windowCount = N * outDepth * outRows * outCols;

    int32_al[filterMatRows][windowLen] filterMat;
    int32_al[windowLen][windowCount] windowMat;
    int32_al[filterMatRows][windowCount] productMat;

    Conv3DReshapeFilter(FD, FH, FW, CI, CO, filterArr, filterMat);
    Conv3DReshapeInput(N, D, H, W, CI, FD, FH, FW, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, windowLen, windowCount, inputArr, windowMat);

    (* The meaning of the trailing flag is defined by MatMul2D, which is not part of this file section. *)
    MatMul2D(filterMatRows, windowLen, windowCount, filterMat, windowMat, productMat, true);
    Conv3DReshapeMatMulOP(N, outDepth, outRows, outCols, CO, productMat, outArr);

    ClearMemSecret2(filterMatRows, windowLen, filterMat);
    ClearMemSecret2(windowLen, windowCount, windowMat);
    ClearMemSecret2(filterMatRows, windowCount, productMat);
}
(**************************)
(* Loop-based implementation of Conv3D *)
(* The loop implementation of convolution runs faster with multithreading *)
def void Conv3DLoopInner(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadDLeft, int32_pl zPadDRight,int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideD, int32_pl strideH, int32_pl strideW,
int32_pl outD, int32_pl outH, int32_pl outW,
int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][outD][outH][outW][CO] outArr)
{
    (* Direct, loop-based 3-D convolution.
       For every output position and input channel, the products of the input
       window and the filter are summed, skipping positions that fall in the
       zero padding. Each per-channel sum is ADDED to the value already in
       outArr, so the result is only a plain convolution if outArr holds
       zeros on entry; that is the caller's responsibility. *)
    for n=[0:N]{
        for co=[0:CO]{
            for d=[0:outD]{
                for h=[0:outH]{
                    for w=[0:outW]{
                        for ci=[0:CI]{
                            int32_al acc = 0;
                            (* kd, kh, kw walk the filter; srcD, srcH, srcW are the matching input coordinates. *)
                            for kd=[0:FD]{
                                int32_pl srcD = (d*strideD) + kd - zPadDLeft;
                                for kh=[0:FH]{
                                    int32_pl srcH = (h*strideH) + kh - zPadHLeft;
                                    for kw=[0:FW]{
                                        int32_pl srcW = (w*strideW) + kw - zPadWLeft;
                                        if( (srcD >= 0) && (srcH >= 0) && (srcW >= 0) && (srcD < D) && (srcH < H) && (srcW < W)){
                                            acc = acc +_al (inputArr[n][srcD][srcH][srcW][ci]*filterArr[kd][kh][kw][ci][co]);
                                        };
                                    };
                                };
                            };
                            outArr[n][d][h][w][co] = outArr[n][d][h][w][co] + acc;
                        };
                    };
                };
            };
        };
    };
}
(* int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr
*)
def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
int32_pl strideD, int32_pl strideH, int32_pl strideW,
int32_al[N][D][H][W][CI] inputArr,
int32_al[FD][FH][FW][CI][CO] filterArr,
int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr)
{
    (* Entry point for the loop-based 3-D convolution: computes the output
       depth, height and width, then forwards every argument to
       Conv3DLoopInner. *)
    int32_pl outputDepth = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1;
    int32_pl outputHeight = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1;
    int32_pl outputWidth = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1;
    Conv3DLoopInner(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outputDepth, outputHeight, outputWidth, inputArr, filterArr, outArr);
}
(**************************)
(* Generic implementation of ConvTranspose2D *)
(* ConvTranspose2DReshapeMatMulOP: unpacks a [CO][N*finalH*finalW] matrix
   into an [N][finalH][finalW][CO] array. Column
   (n*finalH*finalW) + (h*finalW) + w of row co is written to
   outputArr[n][h][w][co]. *)
def void ConvTranspose2DReshapeMatMulOP(int32_pl N, int32_pl finalH, int32_pl finalW, int32_pl CO, int32_al[CO][N*finalH*finalW] inputArr, int32_al[N][finalH][finalW][CO] outputArr){
    for oc=[0:CO]{
        for b=[0:N]{
            for r=[0:finalH]{
                (* Column of the first element of output row r in image b. *)
                int32_pl rowBase = (b*finalH*finalW) + (r*finalW);
                for c=[0:finalW]{
                    outputArr[b][r][c][oc] = inputArr[oc][rowBase + c];
                };
            };
        };
    };
}
def void ConvTranspose2DReshapeFilter(int32_pl FH, int32_pl FW, int32_pl CO, int32_pl CI, int32_al[FH][FW][CO][CI] inputArr, int32_al[CO][FH*FW*CI] outputArr)
{
    (* Flattens a [FH][FW][CO][CI] filter into a [CO][FH*FW*CI] matrix while
       reversing it along both spatial axes: the column
       (fh*FW*CI) + (fw*CI) + ci of row co is read from
       inputArr[FH-1-fh][FW-1-fw][co][ci]. *)
    for oc=[0:CO]{
        for kh=[0:FH]{
            (* Filter row to read, counted from the bottom. *)
            int32_pl srcKh = FH-1-kh;
            for kw=[0:FW]{
                (* Filter column to read, counted from the right. *)
                int32_pl srcKw = FW-1-kw;
                int32_pl tapBase = (kh*FW*CI) + (kw*CI);
                for ic=[0:CI]{
                    outputArr[oc][tapBase + ic] = inputArr[srcKh][srcKw][oc][ic];
                };
            };
        };
    };
}
(* ConvTranspose2DReshapeInput: builds the window matrix for a transposed
   2-D convolution.
   The [N][HPrime][WPrime][CI] input is treated as if strideH-1 zero rows and
   strideW-1 zero columns were inserted between neighbouring elements, giving
   a dilated grid of HPrimeTilde x WPrimeTilde, which is then zero-padded. A
   FH x FW window slides over that grid one position at a time. The window's
   values are written down one column of outputArr at row
   (fh*FW*CI) + (fw*CI) + ci. Only grid positions that are multiples of the
   stride on both axes carry input data; all others are written as 0. *)
def void ConvTranspose2DReshapeInput(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, int32_pl FH, int32_pl FW, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][HPrime][WPrime][CI] inputArr, int32_al[RRows][RCols] outputArr){
    (* Index of the output column being filled. *)
    int32_pl patchCol = 0;
    for n=[0:N]{
        (* Height of the dilated grid and its last padded row index. *)
        int32_pl dilatedH = HPrime + ((HPrime-1)*(strideH-1));
        int32_pl lastPadH = dilatedH - 1 + zPadTrHRight;
        int32_pl winTop = 0 - zPadTrHLeft;
        while((winTop + FH - 1) <= lastPadH){
            (* Width of the dilated grid and its last padded column index. *)
            int32_pl dilatedW = WPrime + ((WPrime-1)*(strideW-1));
            int32_pl lastPadW = dilatedW - 1 + zPadTrWRight;
            int32_pl winLeft = 0 - zPadTrWLeft;
            while((winLeft + FW - 1) <= lastPadW){
                for fh=[0:FH]{
                    for fw=[0:FW]{
                        int32_pl gridH = winTop + fh;
                        int32_pl gridW = winLeft + fw;
                        int32_pl tapBase = (fh*FW*CI) + (fw*CI);
                        int32_al val = 0;
                        for ci=[0:CI]{
                            if ((gridH >= 0) && (gridH < dilatedH) && (gridW >= 0) && (gridW < dilatedW)){
                                (* Inside the dilated grid: data exists only on stride multiples. *)
                                if (((gridH % strideH) == 0) && ((gridW % strideW) == 0)) {
                                    int32_pl inH = gridH / strideH;
                                    int32_pl inW = gridW / strideW;
                                    val = inputArr[n][inH][inW][ci];
                                }
                                else{
                                    (* A zero inserted between inputs, i.e. the fractional stride. *)
                                    val = 0;
                                };
                            }
                            else{
                                (* Position lies in the zero padding. *)
                                val = 0;
                            };
                            outputArr[tapBase + ci][patchCol] = val;
                        };
                    };
                };
                patchCol = patchCol + 1;
                (* The window always advances by 1; the stride is applied through the dilation. *)
                winLeft = winLeft + 1;
            };
            (* The window always advances by 1; the stride is applied through the dilation. *)
            winTop = winTop + 1;
        };
    };
}
(* int32_al[N][HPrime][WPrime][CI] inputArr,
int32_al[FH][FW][CO][CI] filter,
int32_al[N][H][W][CO] outputArr
*)
def void ConvTranspose2D(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl H, int32_pl W,
int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
int32_pl strideH, int32_pl strideW,
int32_al[N][HPrime][WPrime][CI] inputArr,
int32_al[FH][FW][CO][CI] filterArr,
int32_al[N][H][W][CO] outArr)
{
    (* Transposed 2-D convolution expressed as one matrix multiplication:
       1. flatten the spatially reversed filter to a CO x (FH*FW*CI) matrix,
       2. lay out every window of the dilated, padded input as a column of a
          (FH*FW*CI) x (N*H*W) matrix,
       3. multiply the two with MatMul2D,
       4. unpack the CO x (N*H*W) product into outArr.
       The output extents H and W are supplied by the caller. The three
       temporaries are released with ClearMemSecret2 at the end. *)
    int32_pl filterMatRows = CO;
    int32_pl windowLen = FH*FW*CI;
    int32_pl windowCount = N * H * W;

    int32_al[filterMatRows][windowLen] filterMat;
    int32_al[windowLen][windowCount] windowMat;
    int32_al[filterMatRows][windowCount] productMat;

    ConvTranspose2DReshapeFilter(FH, FW, CO, CI, filterArr, filterMat);
    ConvTranspose2DReshapeInput(N, HPrime, WPrime, CI, FH, FW, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideH, strideW, windowLen, windowCount, inputArr, windowMat);

    (* The meaning of the trailing flag is defined by MatMul2D, which is not part of this file section. *)
    MatMul2D(filterMatRows, windowLen, windowCount, filterMat, windowMat, productMat, true);
    ConvTranspose2DReshapeMatMulOP(N, H, W, CO, productMat, outArr);

    ClearMemSecret2(filterMatRows, windowLen, filterMat);
    ClearMemSecret2(windowLen, windowCount, windowMat);
    ClearMemSecret2(filterMatRows, windowCount, productMat);
}
(**************************)
(* Generic implementation of ConvTranspose3D *)
(* ConvTranspose3DReshapeFilter: flattens a 5-D filter into a 2-D matrix.
   inputArr  : int32_al[FD][FH][FW][CO][CI]  (the filter, despite the parameter name)
   outputArr : int32_al[CO][FD*FH*FW*CI]
   Row co of outputArr holds the taps of output channel co laid out in
   (fd, fh, fw, ci) row-major order, with the three spatial axes mirrored:
   column (fd, fh, fw, ci) is read from inputArr[FD-1-fd][FH-1-fh][FW-1-fw][co][ci].
*)
def void ConvTranspose3DReshapeFilter(int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, int32_pl CI, int32_al[FD][FH][FW][CO][CI] inputArr, int32_al[CO][FD*FH*FW*CI] outputArr)
{
	for oc=[0:CO]{
		for kd=[0:FD]{
			(* Mirrored source index along depth. *)
			int32_pl srcD = FD-1-kd;
			for kh=[0:FH]{
				int32_pl srcH = FH-1-kh;
				for kw=[0:FW]{
					int32_pl srcW = FW-1-kw;
					(* Column of the first input channel for this (kd, kh, kw) tap. *)
					int32_pl colBase = (kd*FH*FW*CI) + (kh*FW*CI) + (kw*CI);
					for ic=[0:CI]{
						outputArr[oc][colBase + ic] = inputArr[srcD][srcH][srcW][oc][ic];
					};
				};
			};
		};
	};
}
(* ConvTranspose3DReshapeInput: builds the patch matrix used by the matmul-based ConvTranspose3D.
   inputArr  : int32_al[N][DPrime][HPrime][WPrime][CI]
   outputArr : int32_al[RRows][RCols]
   The input is treated as if (stride-1) zeros were inserted between neighbouring
   elements along each spatial axis and zPadTr*Left / zPadTr*Right zeros were added
   around it. A window of size FD x FH x FW slides over that virtual volume with a
   step of 1 along every axis. Each window position fills one column of outputArr;
   row (fd, fh, fw, ci) in row-major order holds the value under that tap, or 0 when
   the tap falls on padding or on an inserted zero.
   Columns are filled in order of n, then depth, then height, then width.
*)
def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, int32_pl strideD, int32_pl strideH, int32_pl strideW, int32_pl RRows, int32_pl RCols, int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, int32_al[RRows][RCols] outputArr){
	(* Extent of the input along each axis once the stride zeros are inserted. *)
	int32_pl dilatedD = DPrime + ((DPrime-1)*(strideD-1));
	int32_pl dilatedH = HPrime + ((HPrime-1)*(strideH-1));
	int32_pl dilatedW = WPrime + ((WPrime-1)*(strideW-1));

	(* Largest coordinate (inclusive) a window may cover, right padding included. *)
	int32_pl lastCoordD = dilatedD - 1 + zPadTrDRight;
	int32_pl lastCoordH = dilatedH - 1 + zPadTrHRight;
	int32_pl lastCoordW = dilatedW - 1 + zPadTrWRight;

	(* Column of outputArr being filled; advances once per window position. *)
	int32_pl patchCol = 0;

	for n=[0:N]{
		int32_pl winD = 0 - zPadTrDLeft;
		while((winD + FD - 1) <= lastCoordD){
			int32_pl winH = 0 - zPadTrHLeft;
			while((winH + FH - 1) <= lastCoordH){
				int32_pl winW = 0 - zPadTrWLeft;
				while((winW + FW - 1) <= lastCoordW){

					for fd=[0:FD]{
						int32_pl posD = winD + fd;
						for fh=[0:FH]{
							int32_pl posH = winH + fh;
							for fw=[0:FW]{
								int32_pl posW = winW + fw;
								(* Row of the first input channel for this tap. *)
								int32_pl rowBase = (fd*FH*FW*CI) + (fh*FW*CI) + (fw*CI);

								int32_al val = 0;
								for ci=[0:CI]{
									if (((posD >= 0) && (posD < dilatedD)) && ((posH >= 0) && (posH < dilatedH)) && ((posW >= 0) && (posW < dilatedW))) {
										(* Inside the dilated input: only multiples of the stride carry real data. *)
										if (((posD % strideD) == 0) && ((posH % strideH) == 0) && ((posW % strideW) == 0)) {
											int32_pl srcD = posD / strideD;
											int32_pl srcH = posH / strideH;
											int32_pl srcW = posW / strideW;
											val = inputArr[n][srcD][srcH][srcW][ci];
										}
										else{
											val = 0; (* Inserted zero: this represents the fractional stride. *)
										};
									}
									else{
										val = 0; (* Tap lies in the zero padding. *)
									};
									outputArr[rowBase + ci][patchCol] = val;
								};
							};
						};
					};

					patchCol = patchCol + 1;
					winW = winW + 1; (* Imp Note: The actual stride is always 1 *)
				};
				winH = winH + 1; (* Imp Note: The actual stride is always 1 *)
			};
			winD = winD + 1; (* Imp Note: The actual stride is always 1 *)
		};
	};
}
(* ConvTranspose3D: transposed 3-D convolution expressed as a single matrix product.
   Shapes:
     inputArr  : int32_al[N][DPrime][HPrime][WPrime][CI]
     filterArr : int32_al[FD][FH][FW][CO][CI]
     outArr    : int32_al[N][D][H][W][CO]
   Steps: flatten the filter to [CO][FD*FH*FW*CI], flatten the input to
   [FD*FH*FW*CI][N*D*H*W], multiply, then reshape the [CO][N*D*H*W] product into
   outArr. The three scratch matrices are handed to ClearMemSecret2 before returning.
*)
def void ConvTranspose3D(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
				   int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
				   int32_pl D, int32_pl H, int32_pl W,
				   int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
				   int32_pl strideD, int32_pl strideH, int32_pl strideW,
				   int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
				   int32_al[FD][FH][FW][CO][CI] filterArr,
				   int32_al[N][D][H][W][CO] outArr)
{
	(* One filter row per output channel. *)
	int32_pl numFilterRows = CO;
	(* Shared inner dimension of the product: one entry per (fd, fh, fw, ci) tuple. *)
	int32_pl innerDim = FD*FH*FW*CI;
	(* One column per output position (n, d, h, w). *)
	int32_pl numOutPositions = N * D * H * W;

	int32_al[numFilterRows][innerDim] flatFilter;
	int32_al[innerDim][numOutPositions] flatInput;
	int32_al[numFilterRows][numOutPositions] product;

	ConvTranspose3DReshapeFilter(FD, FH, FW, CO, CI, filterArr, flatFilter);
	ConvTranspose3DReshapeInput(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, innerDim, numOutPositions, inputArr, flatInput);

	(* NOTE(review): the meaning of the trailing boolean flag is defined by MatMul2D, which is not visible here. *)
	MatMul2D(numFilterRows, innerDim, numOutPositions, flatFilter, flatInput, product, true);

	(* The product is reshaped with the same helper the forward 3-D convolution uses. *)
	Conv3DReshapeMatMulOP(N, D, H, W, CO, product, outArr);

	ClearMemSecret2(numFilterRows, innerDim, flatFilter);
	ClearMemSecret2(innerDim, numOutPositions, flatInput);
	ClearMemSecret2(numFilterRows, numOutPositions, product);
}
(**************************)
(* Loop-based implementation of ConvTranspose3D *)
(* ConvTranspose3DLoopInner: direct loop-nest transposed 3-D convolution.
   Shapes:
     inputArr  : int32_al[N][D][H][W][CI]
     filterArr : int32_al[FD][FH][FW][CO][CI]
     outArr    : int32_al[N][outD][outH][outW][CO]
   For every output position (d, h, w) the window fd in [d, d+FD), fh in [h, h+FH),
   fw in [w, w+FW) is scanned. A window coordinate contributes only when, after
   subtracting the left padding, it is a multiple of the stride of its own axis and
   the quotient lies inside the input. The filter tap paired with it is the
   spatially mirrored one: (FD+d-fd-1, FH+h-fh-1, FW+w-fw-1).
   The right-padding parameters are accepted but not read in this body.
   NOTE: results are accumulated into outArr (outArr = outArr + val), so the final
   value depends on the contents of outArr on entry; presumably callers pass a
   zero-initialised array - TODO confirm.
*)
def void ConvTranspose3DLoopInner(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI,
				   int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
				   int32_pl zPadDLeft, int32_pl zPadDRight,int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight,
				   int32_pl strideD, int32_pl strideH, int32_pl strideW,
				   int32_pl outD, int32_pl outH, int32_pl outW,
				   int32_al[N][D][H][W][CI] inputArr,
				   int32_al[FD][FH][FW][CO][CI] filterArr,
				   int32_al[N][outD][outH][outW][CO] outArr)
{
	for n=[0:N]{
		for co=[0:CO]{
			for d=[0:outD]{
				for h=[0:outH]{
					for w=[0:outW]{
						for ci=[0:CI]{
							(* Contribution of input channel ci to this output element. *)
							int32_al val = 0;
							for fd=[d:d+FD]{
								for fh=[h:h+FH]{
									for fw=[w:w+FW]{
										(* Each axis must be divided by its OWN stride; the alignment
										   test below already uses strideD / strideH / strideW. *)
										int32_pl curPosD = (fd-zPadDLeft)/strideD;
										int32_pl curPosH = (fh-zPadHLeft)/strideH;
										int32_pl curPosW = (fw-zPadWLeft)/strideW;

										if( (curPosD >= 0) && (curPosH >= 0) && (curPosW >= 0) && (curPosD < D) && (curPosH < H) && (curPosW < W) && ((fd-zPadDLeft)%strideD == 0) && ((fh-zPadHLeft)%strideH == 0) && ((fw-zPadWLeft)%strideW == 0)){
											(* Mirrored filter tap for this window offset. *)
											int32_pl curFilterPosD = FD+d-fd-1;
											int32_pl curFilterPosH = FH+h-fh-1;
											int32_pl curFilterPosW = FW+w-fw-1;
											val = val +_al (inputArr[n][curPosD][curPosH][curPosW][ci]*filterArr[curFilterPosD][curFilterPosH][curFilterPosW][co][ci]);
										};
									};
								};
							};
							outArr[n][d][h][w][co] = outArr[n][d][h][w][co] + val;
						};
					};
				};
			};
		};
	};
}
(* ConvTranspose3DLoop: entry point of the loop-based transposed 3-D convolution.
   Shapes:
     inputArr  : int32_al[N][DPrime][HPrime][WPrime][CI]
     filterArr : int32_al[FD][FH][FW][CO][CI]
     outArr    : int32_al[N][D][H][W][CO]
   Thin wrapper: forwards every argument to ConvTranspose3DLoopInner, passing the
   input extents (DPrime, HPrime, WPrime) first and the output extents (D, H, W)
   after the strides, which is the parameter order the inner routine expects.
*)
def void ConvTranspose3DLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI,
int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO,
int32_pl D, int32_pl H, int32_pl W,
int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight,
int32_pl strideD, int32_pl strideH, int32_pl strideW,
int32_al[N][DPrime][HPrime][WPrime][CI] inputArr,
int32_al[FD][FH][FW][CO][CI] filterArr,
int32_al[N][D][H][W][CO] outArr)
{
ConvTranspose3DLoopInner(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, outArr);
}
(**************************)
(* Transpose2: writes the transpose of inArr (s2 x s1) into outArr (s1 x s2). *)
def void Transpose2(int32_pl s1, int32_pl s2, int32_al[s2][s1] inArr, int32_al[s1][s2] outArr){
	for outRow=[0:s1]{
		for outCol=[0:s2]{
			(* Element (outRow, outCol) of the result comes from (outCol, outRow) of the source. *)
			outArr[outRow][outCol] = inArr[outCol][outRow];
		};
	};
}
(**************************)
(* Pad442: zero-pads a 4-D tensor.
   inpArr   : int32_al[inps1][inps2][inps3][inps4]
   paddings : int32_pl[pads1][pads2]; paddings[k][0] is the number of zeros added
              before axis k and paddings[k][1] the number added after it
   outArr   : int32_al[s1][s2][s3][s4]
   Every output element inside the unpadded region is copied from inpArr; all
   other elements are set to 0.
*)
def void Pad442(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_al[inps1][inps2][inps3][inps4] inpArr, int32_pl pads1, int32_pl pads2, int32_pl[pads1][pads2] paddings, int32_al[s1][s2][s3][s4] outArr){
	(* Per axis: first index holding data, and one past the last such index. *)
	int32_pl begin1 = paddings[0][0];
	int32_pl begin2 = paddings[1][0];
	int32_pl begin3 = paddings[2][0];
	int32_pl begin4 = paddings[3][0];
	int32_pl end1 = s1-paddings[0][1];
	int32_pl end2 = s2-paddings[1][1];
	int32_pl end3 = s3-paddings[2][1];
	int32_pl end4 = s4-paddings[3][1];

	for a=[0:s1]{
		for b=[0:s2]{
			for c=[0:s3]{
				for e=[0:s4]{
					if ((a >= begin1) && (a < end1) && (b >= begin2) && (b < end2) && (c >= begin3) && (c < end3) && (e >= begin4) && (e < end4)){
						outArr[a][b][c][e] = inpArr[a-begin1][b-begin2][c-begin3][e-begin4];
					}
					else{
						outArr[a][b][c][e] = 0;
					};
				};
			};
		};
	};
}
(* Pad552: zero-pads a 5-D tensor.
   inpArr   : int32_al[inps1][inps2][inps3][inps4][inps5]
   paddings : int32_pl[pads1][pads2]; paddings[k][0] is the number of zeros added
              before axis k and paddings[k][1] the number added after it
   outArr   : int32_al[s1][s2][s3][s4][s5]
   Every output element inside the unpadded region is copied from inpArr; all
   other elements are set to 0.
*)
def void Pad552(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl inps1, int32_pl inps2, int32_pl inps3, int32_pl inps4, int32_pl inps5, int32_al[inps1][inps2][inps3][inps4][inps5] inpArr, int32_pl pads1, int32_pl pads2, int32_pl[pads1][pads2] paddings, int32_al[s1][s2][s3][s4][s5] outArr){
	(* Per axis: first index holding data, and one past the last such index. *)
	int32_pl begin1 = paddings[0][0];
	int32_pl begin2 = paddings[1][0];
	int32_pl begin3 = paddings[2][0];
	int32_pl begin4 = paddings[3][0];
	int32_pl begin5 = paddings[4][0];
	int32_pl end1 = s1-paddings[0][1];
	int32_pl end2 = s2-paddings[1][1];
	int32_pl end3 = s3-paddings[2][1];
	int32_pl end4 = s4-paddings[3][1];
	int32_pl end5 = s5-paddings[4][1];

	for a=[0:s1]{
		for b=[0:s2]{
			for c=[0:s3]{
				for e=[0:s4]{
					for f=[0:s5]{
						if ((a >= begin1) && (a < end1) && (b >= begin2) && (b < end2) && (c >= begin3) && (c < end3) && (e >= begin4) && (e < end4) && (f >= begin5) && (f < end5)){
							outArr[a][b][c][e][f] = inpArr[a-begin1][b-begin2][c-begin3][e-begin4][f-begin5];
						}
						else{
							outArr[a][b][c][e][f] = 0;
						};
					};
				};
			};
		};
	};
}
(* PadONNX441: zero-pads a 4-D tensor using a flat padding vector.
   inpArr   : int32_al[i1][i2][i3][i4]
   paddings : int32_pl[pads]; entries 0..3 are the zeros added before axes 1..4,
              entries 4..7 the zeros added after axes 1..4
   outArr   : int32_al[o1][o2][o3][o4]
   Every output element inside the unpadded region is copied from inpArr; all
   other elements are set to 0.
*)
def void PadONNX441(int32_pl o1, int32_pl o2, int32_pl o3, int32_pl o4, int32_pl i1, int32_pl i2, int32_pl i3, int32_pl i4, int32_al[i1][i2][i3][i4] inpArr, int32_pl pads, int32_pl[pads] paddings, int32_al[o1][o2][o3][o4] outArr) {
	(* Per axis: first index holding data, and one past the last such index. *)
	int32_pl begin1 = paddings[0];
	int32_pl begin2 = paddings[1];
	int32_pl begin3 = paddings[2];
	int32_pl begin4 = paddings[3];
	int32_pl end1 = o1 - paddings[4];
	int32_pl end2 = o2 - paddings[5];
	int32_pl end3 = o3 - paddings[6];
	int32_pl end4 = o4 - paddings[7];

	for a=[0:o1]{
		for b=[0:o2]{
			for c=[0:o3]{
				for e=[0:o4]{
					if ((a >= begin1) && (a < end1) && (b >= begin2) && (b < end2) && (c >= begin3) && (c < end3) && (e >= begin4) && (e < end4)){
						outArr[a][b][c][e] = inpArr[a-begin1][b-begin2][c-begin3][e-begin4];
					}
					else{
						outArr[a][b][c][e] = 0;
					};
				};
			};
		};
	};
}
(**************************)
(* Squeeze24: copies a 4-D tensor into a 2-D tensor, i.e. two axes are squeezed away.
   inArr  : int32_al[ins1][ins2][ins3][ins4]
   outArr : int32_al[s1][s2]
   The copy is a pure row-major relayout: the element with linear index k in inArr
   is written to outArr[k / s2][k % s2]. dim1 and dim2 are accepted but not read.
*)
def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32_pl ins1, int32_pl ins2, int32_pl ins3, int32_pl ins4, int32_al[ins1][ins2][ins3][ins4] inArr, int32_al[s1][s2] outArr)
{
	(* Both arrays hold the same number of elements, so walk the input linearly. *)
	for a=[0:ins1]{
		int32_pl off1 = a*ins2*ins3*ins4;
		for b=[0:ins2]{
			int32_pl off2 = off1 + (b*ins3*ins4);
			for c=[0:ins3]{
				int32_pl off3 = off2 + (c*ins4);
				for e=[0:ins4]{
					int32_pl flatPos = off3 + e;
					outArr[flatPos / s2][flatPos % s2] = inArr[a][b][c][e];
				};
			};
		};
	};
}
(* Squeeze34: copies a 4-D tensor into a 3-D tensor, i.e. one axis is squeezed away.
   inArr  : int32_al[ins1][ins2][ins3][ins4]
   outArr : int32_al[s1][s2][s3]
   The copy is a pure row-major relayout: the linear index of each input element
   is decomposed into (outer, middle, inner) coordinates of outArr.
   dim is accepted but not read.
*)
def void Squeeze34(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl dim, int32_pl ins1, int32_pl ins2, int32_pl ins3, int32_pl ins4, int32_al[ins1][ins2][ins3][ins4] inArr, int32_al[s1][s2][s3] outArr)
{
	(* Both arrays hold the same number of elements, so walk the input linearly. *)
	for a=[0:ins1]{
		int32_pl off1 = a*ins2*ins3*ins4;
		for b=[0:ins2]{
			int32_pl off2 = off1 + (b*ins3*ins4);
			for c=[0:ins3]{
				int32_pl off3 = off2 + (c*ins4);
				for e=[0:ins4]{
					int32_pl flatPos = off3 + e;
					int32_pl outer = flatPos / (s2*s3);
					(* Position inside the selected outer slab. *)
					int32_pl withinSlab = flatPos - (outer*s2*s3);
					int32_pl middle = withinSlab / s3;
					int32_pl inner = withinSlab - (middle*s3);
					outArr[outer][middle][inner] = inArr[a][b][c][e];
				};
			};
		};
	};
}
(* Squeeze23: copies a 3-D tensor into a 2-D tensor, i.e. one axis is squeezed away.
   inArr  : int32_al[ins1][ins2][ins3]
   outArr : int32_al[s1][s2]
   The copy is a pure row-major relayout: the element with linear index k in inArr
   is written to outArr[k / s2][k % s2]. dim is accepted but not read.
*)
def void Squeeze23(int32_pl s1, int32_pl s2, int32_pl dim, int32_pl ins1, int32_pl ins2, int32_pl ins3, int32_al[ins1][ins2][ins3] inArr, int32_al[s1][s2] outArr)
{
	(* Both arrays hold the same number of elements, so walk the input linearly. *)
	for a=[0:ins1]{
		int32_pl off1 = a*ins2*ins3;
		for b=[0:ins2]{
			int32_pl off2 = off1 + (b*ins3);
			for c=[0:ins3]{
				int32_pl flatPos = off2 + (c);
				outArr[flatPos / s2][flatPos % s2] = inArr[a][b][c];
			};
		};
	};
}
(**************************)
(* FusedBatchNorm4411: per-channel affine transform of a 4-D tensor.
   inArr, outputArr : int32_al[s1][s2][s3][s4]; the last axis indexes the channel
   multArr, biasArr : int32_al[s4], one multiplier and one bias per channel
   Computes outputArr[..][c] = (inArr[..][c] * multArr[c]) + biasArr[c], where
     - the products are passed through ScaleDown when multExprScaleDownSf > 0, and
     - a private copy of the bias is passed through ScaleUp when biasExprScaleUpSf > 0.
   The multiplication is done on flattened 1-D buffers via ElemWiseActModelVectorMult.
   biasArr itself is never modified.
*)
def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] inArr, int32_al[s4] multArr, int32_al[s4] biasArr, int32_pl multExprScaleDownSf, int32_pl biasExprScaleUpSf, int32_al[s1][s2][s3][s4] outputArr){
	int32_pl numElems = s1*s2*s3*s4;
	int32_al[numElems] flatIn;
	int32_al[numElems] flatMult;
	int32_al[numElems] flatProd;

	(* Flatten the input and replicate the per-channel multiplier alongside it. *)
	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for ch=[0:s4]{
					flatIn[off3 + ch] = inArr[a][b][c][ch];
					flatMult[off3 + ch] = multArr[ch];
				};
			};
		};
	};

	ElemWiseActModelVectorMult(numElems, flatIn, flatMult, flatProd);
	if (multExprScaleDownSf > 0) {
		ScaleDown(numElems, flatProd, multExprScaleDownSf);
	};

	(* Work on a copy so that the bias passed in is left untouched. *)
	int32_al[s4] scaledBias;
	for ch=[0:s4]{
		scaledBias[ch] = biasArr[ch];
	};
	if (biasExprScaleUpSf > 0){
		ScaleUp(s4, scaledBias, biasExprScaleUpSf);
	};

	(* Add the bias and restore the 4-D layout. *)
	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for ch=[0:s4]{
					outputArr[a][b][c][ch] = flatProd[off3 + ch] + scaledBias[ch];
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatIn);
	ClearMemSecret1(numElems, flatMult);
	ClearMemSecret1(numElems, flatProd);
	ClearMemSecret1(s4, scaledBias);
}
(* FusedBatchNorm5511: per-channel affine transform of a 5-D tensor.
   inArr, outputArr : int32_al[s1][s2][s3][s4][s5]; the last axis indexes the channel
   multArr, biasArr : int32_al[s5], one multiplier and one bias per channel
   Computes outputArr[..][c] = (inArr[..][c] * multArr[c]) + biasArr[c], where
     - the products are passed through ScaleDown when multExprScaleDownSf > 0, and
     - a private copy of the bias is passed through ScaleUp when biasExprScaleUpSf > 0.
   The multiplication is done on flattened 1-D buffers via ElemWiseActModelVectorMult.
   biasArr itself is never modified.
*)
def void FusedBatchNorm5511(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] inArr, int32_al[s5] multArr, int32_al[s5] biasArr, int32_pl multExprScaleDownSf, int32_pl biasExprScaleUpSf, int32_al[s1][s2][s3][s4][s5] outputArr){
	int32_pl numElems = s1*s2*s3*s4*s5;
	int32_al[numElems] flatIn;
	int32_al[numElems] flatMult;
	int32_al[numElems] flatProd;

	(* Flatten the input and replicate the per-channel multiplier alongside it. *)
	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for ch=[0:s5]{
						flatIn[off4 + ch] = inArr[a][b][c][e][ch];
						flatMult[off4 + ch] = multArr[ch];
					};
				};
			};
		};
	};

	ElemWiseActModelVectorMult(numElems, flatIn, flatMult, flatProd);
	if (multExprScaleDownSf > 0) {
		ScaleDown(numElems, flatProd, multExprScaleDownSf);
	};

	(* Work on a copy so that the bias passed in is left untouched. *)
	int32_al[s5] scaledBias;
	for ch=[0:s5]{
		scaledBias[ch] = biasArr[ch];
	};
	if (biasExprScaleUpSf > 0){
		ScaleUp(s5, scaledBias, biasExprScaleUpSf);
	};

	(* Add the bias and restore the 5-D layout. *)
	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for ch=[0:s5]{
						outputArr[a][b][c][e][ch] = flatProd[off4 + ch] + scaledBias[ch];
					};
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatIn);
	ClearMemSecret1(numElems, flatMult);
	ClearMemSecret1(numElems, flatProd);
	ClearMemSecret1(s5, scaledBias);
}
(**************************)
(* ElemWiseMul2: element-wise product of two 2-D tensors of shape [s1][s2].
   Both operands are flattened row-major, multiplied with
   ElemWiseSecretSharedVectorMult, and the result is copied back into outArr.
*)
def void ElemWiseMul2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr1, int32_al[s1][s2] arr2, int32_al[s1][s2] outArr)
{
	int32_pl numElems = s1*s2;
	int32_al[numElems] flatLhs;
	int32_al[numElems] flatRhs;
	int32_al[numElems] flatProd;

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			flatLhs[rowStart + b] = arr1[a][b];
			flatRhs[rowStart + b] = arr2[a][b];
		};
	};

	ElemWiseSecretSharedVectorMult(numElems, flatLhs, flatRhs, flatProd);

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			outArr[a][b] = flatProd[rowStart + b];
		};
	};

	ClearMemSecret1(numElems, flatLhs);
	ClearMemSecret1(numElems, flatRhs);
	ClearMemSecret1(numElems, flatProd);
}
(* ElemWiseMul4: element-wise product of two 4-D tensors of shape [s1][s2][s3][s4].
   Both operands are flattened row-major, multiplied with
   ElemWiseSecretSharedVectorMult, and the result is copied back into outArr.
*)
def void ElemWiseMul4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr1, int32_al[s1][s2][s3][s4] arr2, int32_al[s1][s2][s3][s4] outArr)
{
	int32_pl numElems = s1*s2*s3*s4;
	int32_al[numElems] flatLhs;
	int32_al[numElems] flatRhs;
	int32_al[numElems] flatProd;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					flatLhs[off3 + e] = arr1[a][b][c][e];
					flatRhs[off3 + e] = arr2[a][b][c][e];
				};
			};
		};
	};

	ElemWiseSecretSharedVectorMult(numElems, flatLhs, flatRhs, flatProd);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					outArr[a][b][c][e] = flatProd[off3 + e];
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatLhs);
	ClearMemSecret1(numElems, flatRhs);
	ClearMemSecret1(numElems, flatProd);
}
(* ElemWiseMul5: element-wise product of two 5-D tensors of shape [s1][s2][s3][s4][s5].
   Both operands are flattened row-major, multiplied with
   ElemWiseSecretSharedVectorMult, and the result is copied back into outArr.
*)
def void ElemWiseMul5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr1, int32_al[s1][s2][s3][s4][s5] arr2, int32_al[s1][s2][s3][s4][s5] outArr)
{
	int32_pl numElems = s1*s2*s3*s4*s5;
	int32_al[numElems] flatLhs;
	int32_al[numElems] flatRhs;
	int32_al[numElems] flatProd;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for f=[0:s5]{
						flatLhs[off4 + f] = arr1[a][b][c][e][f];
						flatRhs[off4 + f] = arr2[a][b][c][e][f];
					};
				};
			};
		};
	};

	ElemWiseSecretSharedVectorMult(numElems, flatLhs, flatRhs, flatProd);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for f=[0:s5]{
						outArr[a][b][c][e][f] = flatProd[off4 + f];
					};
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatLhs);
	ClearMemSecret1(numElems, flatRhs);
	ClearMemSecret1(numElems, flatProd);
}
(**************************)
(* ReduceMean24: mean of a 4-D tensor over its two middle axes.
   inputArr  : int32_al[inS1][inS2][inS3][inS4]
   outputArr : int32_al[outS1][outS2]
   outputArr[i][j] is the sum of inputArr[i][p][q][j] over all p in [0, inS2) and
   q in [0, inS3), divided by the public constant inS2*inS3 through
   ElemWiseVectorPublicDiv. The axes argument is accepted but not read; the reduced
   axes are fixed by the code.
*)
def void ReduceMean24(int32_pl outS1, int32_pl outS2,
					  int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4,
					  int32_al[inS1][inS2][inS3][inS4] inputArr,
					  int32_pl[2] axes,
					  int32_al[outS1][outS2] outputArr
					  )
{
	(* Number of elements folded into each output value. *)
	int32_pl reducedCount = inS2*inS3;
	int32_pl numOut = outS1*outS2;
	int32_al[numOut] flatSums;
	int32_al[numOut] flatMeans;

	for r=[0:outS1]{
		for c=[0:outS2]{
			int32_al acc = 0;
			for p=[0:inS2]{
				for q=[0:inS3]{
					acc = acc + inputArr[r][p][q][c];
				};
			};
			flatSums[(r*outS2) + c] = acc;
		};
	};

	ElemWiseVectorPublicDiv(numOut, flatSums, reducedCount, flatMeans);

	for r=[0:outS1]{
		for c=[0:outS2]{
			outputArr[r][c] = flatMeans[(r*outS2) + c];
		};
	};

	ClearMemSecret1(numOut, flatSums);
	ClearMemSecret1(numOut, flatMeans);
}
(* ReduceMeanONNX24: mean of a 4-D tensor over its two trailing axes.
   inputArr  : int32_al[inS1][inS2][inS3][inS4]
   outputArr : int32_al[outS1][outS2]
   outputArr[i][j] is the sum of inputArr[i][j][p][q] over all p in [0, inS3) and
   q in [0, inS4), divided by the public constant inS3*inS4 through
   ElemWiseVectorPublicDiv. axis1 and axis2 are accepted but not read; the reduced
   axes are fixed by the code.
*)
def void ReduceMeanONNX24(int32_pl outS1, int32_pl outS2,
					  int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4,
					  int32_al[inS1][inS2][inS3][inS4] inputArr,
					  int32_pl axis1, int32_pl axis2,
					  int32_al[outS1][outS2] outputArr
					  )
{
	(* Number of elements folded into each output value. *)
	int32_pl reducedCount = inS3*inS4;
	int32_pl numOut = outS1*outS2;
	int32_al[numOut] flatSums;
	int32_al[numOut] flatMeans;

	for r=[0:outS1]{
		for c=[0:outS2]{
			int32_al acc = 0;
			for p=[0:inS3]{
				for q=[0:inS4]{
					acc = acc + inputArr[r][c][p][q];
				};
			};
			flatSums[(r*outS2) + c] = acc;
		};
	};

	ElemWiseVectorPublicDiv(numOut, flatSums, reducedCount, flatMeans);

	for r=[0:outS1]{
		for c=[0:outS2]{
			outputArr[r][c] = flatMeans[(r*outS2) + c];
		};
	};

	ClearMemSecret1(numOut, flatSums);
	ClearMemSecret1(numOut, flatMeans);
}
(**************************)
(* ArgMax1: thin wrapper that forwards a 2-D input of shape [inArrS1][inArrS2] to ArgMax,
   which writes its result into the 1-D array outArr.
   outArrS1 and dim are accepted but not read in the body.
*)
def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int32_al[inArrS1][inArrS2] inArr, int32_pl dim, int32_al[outArrS1] outArr)
{
(*
Assumptions made by this wrapper:
- The dim argument is ignored; the axis is whatever ArgMax itself uses.
- outArrS1==inArrS1 is taken for granted and not checked.
*)
ArgMax(inArrS1, inArrS2, inArr, outArr);
}
(* ArgMax3: applies ArgMax along the last axis of a 4-D tensor.
   inArr  : int32_al[ins1][ins2][ins3][ins4]
   outArr : int32_al[outs1][outs2][outs3]
   The three leading axes are collapsed into one so that the input becomes a
   [ins1*ins2*ins3][ins4] matrix; ArgMax yields one value per row, and those values
   are scattered back into outArr.
   Assumptions carried over from the original author:
   - The dim argument is ignored.
   - outs1==ins1 && outs2==ins2 && outs3==ins3 is taken for granted and not checked.
*)
def void ArgMax3(int32_pl outs1, int32_pl outs2, int32_pl outs3,
				 int32_pl ins1, int32_pl ins2, int32_pl ins3, int32_pl ins4,
				 int32_al[ins1][ins2][ins3][ins4] inArr, int32_pl dim, int32_al[outs1][outs2][outs3] outArr)
{
	int32_pl numRows = ins1*ins2*ins3;
	int32_al[numRows][ins4] rowsIn;
	int32_al[numRows] rowsOut;

	for a=[0:ins1]{
		int32_pl off1 = a*ins2*ins3;
		for b=[0:ins2]{
			int32_pl off2 = off1 + (b*ins3);
			for c=[0:ins3]{
				int32_pl rowIdx = off2 + (c);
				for e=[0:ins4]{
					rowsIn[rowIdx][e] = inArr[a][b][c][e];
				};
			};
		};
	};

	ArgMax(numRows, ins4, rowsIn, rowsOut);

	for a=[0:ins1]{
		int32_pl off1 = a*ins2*ins3;
		for b=[0:ins2]{
			int32_pl off2 = off1 + (b*ins3);
			for c=[0:ins3]{
				outArr[a][b][c] = rowsOut[off2 + (c)];
			};
		};
	};

	ClearMemSecret2(numRows, ins4, rowsIn);
	ClearMemSecret1(numRows, rowsOut);
}
(**************************)
(* Relu2: applies the 1-D Relu routine to a 2-D tensor of shape [s1][s2].
   inArr is flattened row-major, Relu is called once on the flat buffer with sf and
   doTruncation passed through unchanged, and the result is copied into outArr.
*)
def void Relu2(int32_pl s1, int32_pl s2, int32_al[s1][s2] inArr, int32_al[s1][s2] outArr, int32_pl sf, bool_pl doTruncation)
{
	int32_pl numElems = s1*s2;
	int32_al[numElems] flatIn;
	int32_al[numElems] flatOut;

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			flatIn[rowStart + b] = inArr[a][b];
		};
	};

	Relu(numElems, flatIn, flatOut, sf, doTruncation);

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			outArr[a][b] = flatOut[rowStart + b];
		};
	};

	ClearMemSecret1(numElems, flatIn);
	ClearMemSecret1(numElems, flatOut);
}
(* Relu4: applies the 1-D Relu routine to a 4-D tensor of shape [s1][s2][s3][s4].
   inArr is flattened row-major, Relu is called once on the flat buffer with sf and
   doTruncation passed through unchanged, and the result is copied into outArr.
*)
def void Relu4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] inArr, int32_al[s1][s2][s3][s4] outArr, int32_pl sf, bool_pl doTruncation)
{
	int32_pl numElems = s1*s2*s3*s4;
	int32_al[numElems] flatIn;
	int32_al[numElems] flatOut;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					flatIn[off3 + e] = inArr[a][b][c][e];
				};
			};
		};
	};

	Relu(numElems, flatIn, flatOut, sf, doTruncation);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					outArr[a][b][c][e] = flatOut[off3 + e];
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatIn);
	ClearMemSecret1(numElems, flatOut);
}
(* Relu5: applies the 1-D Relu routine to a 5-D tensor of shape [s1][s2][s3][s4][s5].
   inArr is flattened row-major, Relu is called once on the flat buffer with sf and
   doTruncation passed through unchanged, and the result is copied into outArr.
*)
def void Relu5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] inArr, int32_al[s1][s2][s3][s4][s5] outArr, int32_pl sf, bool_pl doTruncation)
{
	int32_pl numElems = s1*s2*s3*s4*s5;
	int32_al[numElems] flatIn;
	int32_al[numElems] flatOut;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for f=[0:s5]{
						flatIn[off4 + f] = inArr[a][b][c][e][f];
					};
				};
			};
		};
	};

	Relu(numElems, flatIn, flatOut, sf, doTruncation);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for f=[0:s5]{
						outArr[a][b][c][e][f] = flatOut[off4 + f];
					};
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatIn);
	ClearMemSecret1(numElems, flatOut);
}
(**************************)
(* Floor2: applies the 1-D Floor routine to a 2-D tensor of shape [s1][s2].
   inArr is flattened row-major, Floor is called once on the flat buffer with sf
   passed through unchanged, and the result is copied into outArr.
*)
def void Floor2(int32_pl s1, int32_pl s2, int32_al[s1][s2] inArr, int32_al[s1][s2] outArr, int32_pl sf)
{
	int32_pl numElems = s1*s2;
	int32_al[numElems] flatIn;
	int32_al[numElems] flatOut;

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			flatIn[rowStart + b] = inArr[a][b];
		};
	};

	Floor(numElems, flatIn, flatOut, sf);

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			outArr[a][b] = flatOut[rowStart + b];
		};
	};

	ClearMemSecret1(numElems, flatIn);
	ClearMemSecret1(numElems, flatOut);
}
(**************************)
(* ScaleUp1: rank-1 member of the ScaleUpN family.
   Thin wrapper that forwards the 1-D array arr of length s1 and the factor sf
   straight to ScaleUp, which operates on arr in place.
*)
def void ScaleUp1(int32_pl s1, int32_al[s1] arr, int32_pl sf)
{
ScaleUp(s1, arr, sf);
}
(* ScaleUp2: applies the 1-D ScaleUp routine in place to a 2-D tensor of shape [s1][s2].
   arr is copied row-major into a flat scratch buffer, ScaleUp is called on it with
   sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleUp2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			flatBuf[rowStart + b] = arr[a][b];
		};
	};

	ScaleUp(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			arr[a][b] = flatBuf[rowStart + b];
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(* ScaleUp3: applies the 1-D ScaleUp routine in place to a 3-D tensor of shape [s1][s2][s3].
   arr is copied row-major into a flat scratch buffer, ScaleUp is called on it with
   sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleUp3(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2*s3;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3);
			for c=[0:s3]{
				flatBuf[off2 + (c)] = arr[a][b][c];
			};
		};
	};

	ScaleUp(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3);
			for c=[0:s3]{
				arr[a][b][c] = flatBuf[off2 + (c)];
			};
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(* ScaleUp4: applies the 1-D ScaleUp routine in place to a 4-D tensor of shape [s1][s2][s3][s4].
   arr is copied row-major into a flat scratch buffer, ScaleUp is called on it with
   sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleUp4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2*s3*s4;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					flatBuf[off3 + e] = arr[a][b][c][e];
				};
			};
		};
	};

	ScaleUp(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					arr[a][b][c][e] = flatBuf[off3 + e];
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(* ScaleUp5: applies the 1-D ScaleUp routine in place to a 5-D tensor of shape [s1][s2][s3][s4][s5].
   arr is copied row-major into a flat scratch buffer, ScaleUp is called on it with
   sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleUp5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2*s3*s4*s5;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for f=[0:s5]{
						flatBuf[off4 + f] = arr[a][b][c][e][f];
					};
				};
			};
		};
	};

	ScaleUp(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4*s5;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4*s5);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4*s5);
				for e=[0:s4]{
					int32_pl off4 = off3 + (e*s5);
					for f=[0:s5]{
						arr[a][b][c][e][f] = flatBuf[off4 + f];
					};
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(**************************)
(* ScaleDown1: rank-1 member of the ScaleDownN family.
   Thin wrapper that forwards the 1-D array arr of length s1 and the factor sf
   straight to ScaleDown, which operates on arr in place.
*)
def void ScaleDown1(int32_pl s1, int32_al[s1] arr, int32_pl sf)
{
ScaleDown(s1, arr, sf);
}
(* ScaleDown2: applies the 1-D ScaleDown routine in place to a 2-D tensor of shape [s1][s2].
   arr is copied row-major into a flat scratch buffer, ScaleDown is called on it
   with sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleDown2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			flatBuf[rowStart + b] = arr[a][b];
		};
	};

	ScaleDown(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl rowStart = a*s2;
		for b=[0:s2]{
			arr[a][b] = flatBuf[rowStart + b];
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(* ScaleDown3: applies the 1-D ScaleDown routine in place to a 3-D tensor of shape [s1][s2][s3].
   arr is copied row-major into a flat scratch buffer, ScaleDown is called on it
   with sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleDown3(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2*s3;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3);
			for c=[0:s3]{
				flatBuf[off2 + (c)] = arr[a][b][c];
			};
		};
	};

	ScaleDown(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3);
			for c=[0:s3]{
				arr[a][b][c] = flatBuf[off2 + (c)];
			};
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(* ScaleDown4: applies the 1-D ScaleDown routine in place to a 4-D tensor of shape [s1][s2][s3][s4].
   arr is copied row-major into a flat scratch buffer, ScaleDown is called on it
   with sf passed through unchanged, and the buffer is copied back over arr.
*)
def void ScaleDown4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr, int32_pl sf)
{
	int32_pl numElems = s1*s2*s3*s4;
	int32_al[numElems] flatBuf;

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					flatBuf[off3 + e] = arr[a][b][c][e];
				};
			};
		};
	};

	ScaleDown(numElems, flatBuf, sf);

	for a=[0:s1]{
		int32_pl off1 = a*s2*s3*s4;
		for b=[0:s2]{
			int32_pl off2 = off1 + (b*s3*s4);
			for c=[0:s3]{
				int32_pl off3 = off2 + (c*s4);
				for e=[0:s4]{
					arr[a][b][c][e] = flatBuf[off3 + e];
				};
			};
		};
	};

	ClearMemSecret1(numElems, flatBuf);
}
(* ScaleDown5: applies the 1-D ScaleDown routine in place to a 5-D tensor of shape [s1][s2][s3][s4][s5].
   arr is copied row-major into a flat scratch buffer, ScaleDown is called on it
   with sf passed through unchanged, and the buffer is copied back over arr.
   arr is declared int32_al so that it matches the int32_al scratch buffer, the
   1-D ScaleDown call and the other ScaleDownN / ScaleUpN routines in this file.
*)
def void ScaleDown5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr, int32_pl sf)
{
	int32_pl size = s1*s2*s3*s4*s5;
	int32_al[size] reshapedArr;

	(* Row-major copy of arr into the flat buffer. *)
	for i1=[0:s1]{
		for i2=[0:s2]{
			for i3=[0:s3]{
				for i4=[0:s4]{
					for i5=[0:s5]{
						int32_pl linIdx = (i1*s2*s3*s4*s5) + (i2*s3*s4*s5) + (i3*s4*s5) + (i4*s5) + i5;
						reshapedArr[linIdx] = arr[i1][i2][i3][i4][i5];
					};
				};
			};
		};
	};

	ScaleDown(size, reshapedArr, sf);

	(* Copy the scaled values back using the same linear index. *)
	for i1=[0:s1]{
		for i2=[0:s2]{
			for i3=[0:s3]{
				for i4=[0:s4]{
					for i5=[0:s5]{
						int32_pl linIdx = (i1*s2*s3*s4*s5) + (i2*s3*s4*s5) + (i3*s4*s5) + (i4*s5) + i5;
						arr[i1][i2][i3][i4][i5] = reshapedArr[linIdx];
					};
				};
			};
		};
	};

	ClearMemSecret1(size, reshapedArr);
}