diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 170936760d8a7e8a7be2e0d8ec6ede2dc255322c..46bc2a7761e2c3bf09d1db6c40bfc3ca88acd953 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -280,13 +280,16 @@ class TFNodesAST: assert(len(inputsRef) == 2) return (None, AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None)) - def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr): - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1 + def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None): + if imgD: + assert(FD) + assert(strideD) + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 if (paddingUsedStr == "\"SAME\""): # Reference for following: - # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn - totalPaddingH = totalPaddingW = 0 - + # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn + totalPaddingH = totalPaddingW = totalPaddingD = 0 + if (imgH % strideH == 0): totalPaddingH = max(FH - strideH, 0) else: @@ -297,17 +300,30 @@ class TFNodesAST: else: totalPaddingW = max(FW - (imgW % strideW), 0) + if imgD: + if (imgD % strideD == 0): + totalPaddingD = max(FD - strideD, 0) + else: + totalPaddingD = max(FD - (imgD % strideD), 0) + zPadHLeft = totalPaddingH // 2 zPadHRight = totalPaddingH - zPadHLeft zPadWLeft = totalPaddingW // 2 - zPadWRight = totalPaddingW - zPadWLeft + zPadWRight = totalPaddingW - zPadWLeft + + zPadDLeft = totalPaddingD // 2 + zPadDRight = totalPaddingD - zPadDLeft + elif (paddingUsedStr == "\"VALID\""): - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = 0 + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0 else: - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1 + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 - return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] + if imgD: + return [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] + else: + return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -348,6 +364,104 @@ class TFNodesAST: AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), options)) + def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==2) + + stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + assert(stridesUsed[0]==1 and stridesUsed[4]==1) + strideD = stridesUsed[1] + strideH = stridesUsed[2] + strideW = stridesUsed[3] + + inputShape = extraNodeInfoDict[inputsRef[0]][0] + imgD = inputShape[1] + imgH = inputShape[2] + imgW = inputShape[3] + + filterShape = extraNodeInfoDict[inputsRef[1]][0] + FD = filterShape[0] + FH = filterShape[1] + FW = filterShape[2] + + paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + + [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD ) + + options = {} + options[AST.PaddingKeysDict.FD] = FD + options[AST.PaddingKeysDict.FH] = FH + options[AST.PaddingKeysDict.FW] = FW + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = strideD + options[AST.PaddingKeysDict.strideH] = strideH + options[AST.PaddingKeysDict.strideW] = strideW + options[AST.PaddingKeysDict.ConvDim] = 3 + return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx('#'), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + options)) + + def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==3) #output_shape, filter, input + + stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + assert(stridesUsed[0]==1 and stridesUsed[4]==1) + strideD = stridesUsed[1] + strideH = stridesUsed[2] + strideW = stridesUsed[3] + + filterShape = extraNodeInfoDict[inputsRef[1]][0] + FD = filterShape[0] + FH = filterShape[1] + FW = filterShape[2] + + inputShape = extraNodeInfoDict[inputsRef[2]][0] + inputD = inputShape[1] + inputH = inputShape[2] + inputW = inputShape[3] + + outputShape = extraNodeInfoDict[curNode.getName()][0] + outputD = outputShape[1] + outputH = outputShape[2] + outputW = outputShape[3] + + paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + + # Important: Using outputH and outputW in the below is not an error! + # For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse. + # Which is why the call to helper_findPadding makes sense. + # The zPads below are of the conv of which this convTranspose is an inverse. + [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(outputH, outputW, FH, FW, strideH, strideW, paddingUsedStr, imgD = outputD, FD = FD, strideD = strideD) + + options = {} + options[AST.PaddingKeysDict.FD] = FD + options[AST.PaddingKeysDict.FH] = FH + options[AST.PaddingKeysDict.FW] = FW + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = strideD + options[AST.PaddingKeysDict.strideH] = strideH + options[AST.PaddingKeysDict.strideW] = strideW + options[AST.PaddingKeysDict.ConvDim] = 3 + options[AST.PaddingKeysDict.outputImgD] = outputD + options[AST.PaddingKeysDict.outputImgH] = outputH + options[AST.PaddingKeysDict.outputImgW] = outputW + return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), + TFNodesAST.getOperatorsIdx('#T'), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + options)) + def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) @@ -538,6 +652,18 @@ class TFNodesAST: AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), )) + def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef) == 2) + permNodeName = inputsRef[1] + # We need to fetch the tensor value of the perm Node + permNode = graph.__getitem__(permNodeName) + permTensor = permNode.getAttrVal("value").getTensor() + permList = permTensor.getContentAsValArr() + assert(permTensor.getDType().kind == "i") + assert(permTensor.getShapeRef().getRank() == 1) + return (None, AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)) + def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # TODO : Do this in somewhat better way inputsRef = curNode.getInputsRef() @@ -583,4 +709,4 @@ class TFNodesAST: # TFNodesAST.UninterpFuncCallNames.Pack.name, # list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] ) # return (None, retAST) - \ No newline at end of file + diff --git a/Athos/TFEzPCLibrary/Library32_common.ezpc b/Athos/TFEzPCLibrary/Library32_common.ezpc index 1832c826a23e20c125f10a756ea4b6c1c16ee839..e6b1ccd52ebf4b30aef62ee1e1792429cfa01750 100644 --- a/Athos/TFEzPCLibrary/Library32_common.ezpc +++ b/Athos/TFEzPCLibrary/Library32_common.ezpc @@ -455,26 +455,6 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, Conv2DReshapeMatMulOP(N, newH, newW, CO, matmulOP, outArr); } -(* int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI][CO] filterArr, - int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) - -def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, int32_pl G, - int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI][CO] filterArr, - int32_pl consSF, - int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); -} - (**************************) (* Generic implementation of Conv2D with Groups *) @@ -669,27 +649,6 @@ def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int3 }; } -(* int32_al[N][D][H][W][CI] inputArr, - int32_al[FD][FH][FW][CI][CO] filterArr, - int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) -(* Loop implementation of convolution run faster with multithreadin *) -def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int32_al[N][D][H][W][CI] inputArr, - int32_al[FD][FH][FW][CI][CO] filterArr, - int32_pl consSF, - int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); -} - (* int32_al[N][D][H][W][CI] inputArr, int32_al[FD][FH][FW][CI][CO] filterArr, int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr @@ -1016,23 +975,6 @@ def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrim }; } -(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int32_al[FD][FH][FW][CO][CI] filter, - int32_al[N][D][H][W][CO] outputArr -*) -def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl D, int32_pl H, int32_pl W, - int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int32_al[FD][FH][FW][CO][CI] filterArr, - int32_pl consSF, - int32_al[N][D][H][W][CO] outArr) -{ - ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); -} - (* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, int32_al[FD][FH][FW][CO][CI] filter, int32_al[N][D][H][W][CO] outputArr @@ -1081,4 +1023,4 @@ def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr) { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library32_cpp.ezpc b/Athos/TFEzPCLibrary/Library32_cpp.ezpc index 57dd0a8fefa1b753d24ede8ec29d7107cc1487c8..fb657271c462d02b64a2b22a5ac37beecc33715b 100644 --- a/Athos/TFEzPCLibrary/Library32_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library32_cpp.ezpc @@ -84,6 +84,26 @@ def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, }; } +(* int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) + +def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); +} + (**************************) def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, @@ -126,6 +146,26 @@ def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, }; } +(* int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +(* Loop implementation of convolution run faster with multithreadin *) +def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); +} (**************************) def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, @@ -171,7 +211,22 @@ def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int }; }; } - +(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filter, + int32_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int32_al[N][D][H][W][CO] outArr) +{ + ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); +} (**************************) def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int32_al[inArrS1][inArrS2] inArr, int32_pl dim, int32_al[outArrS1] outArr){ @@ -594,4 +649,4 @@ def void StartComputation() def void EndComputation() { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc index 9eda4383ed4a57eeb35d79dda370f18872f27bbd..eb88b554d2e43cdc07ae0cc5c2fc0ec2e0968104 100644 --- a/Athos/TFEzPCLibrary/Library64_common.ezpc +++ b/Athos/TFEzPCLibrary/Library64_common.ezpc @@ -460,20 +460,6 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr *) -def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, int32_pl G, - int64_al[N][H][W][CI] inputArr, - int64_al[FH][FW][CI][CO] filterArr, - int32_pl consSF, - int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); -} (**************************) (* Generic implementation of Conv2D with Groups *) @@ -669,26 +655,6 @@ def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int3 }; } -(* int64_al[N][D][H][W][CI] inputArr, - int64_al[FD][FH][FW][CI][CO] filterArr, - int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) -(* Loop implementation of convolution run faster with multithreadin *) -def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int64_al[N][D][H][W][CI] inputArr, - int64_al[FD][FH][FW][CI][CO] filterArr, - int32_pl consSF, - int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); -} (* int64_al[N][D][H][W][CI] inputArr, int64_al[FD][FH][FW][CI][CO] filterArr, @@ -1015,24 +981,6 @@ def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrim }; }; } - -(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int64_al[FD][FH][FW][CO][CI] filter, - int64_al[N][D][H][W][CO] outputArr -*) -def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl D, int32_pl H, int32_pl W, - int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int64_al[FD][FH][FW][CO][CI] filterArr, - int32_pl consSF, - int64_al[N][D][H][W][CO] outArr) -{ - ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); -} - (* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, int64_al[FD][FH][FW][CO][CI] filter, int64_al[N][D][H][W][CO] outputArr @@ -1081,4 +1029,4 @@ def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr) { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_cpp.ezpc b/Athos/TFEzPCLibrary/Library64_cpp.ezpc index 3c2a6e06d3f05922452c4bb3b398142365adb8d7..c9afe6f35c8216557091b9dcf612e28d489010fe 100644 --- a/Athos/TFEzPCLibrary/Library64_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library64_cpp.ezpc @@ -83,6 +83,25 @@ def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, }; }; } +(* int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) + +def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int32_pl consSF, + int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); +} (**************************) def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, @@ -126,6 +145,26 @@ def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, }; } +(* int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +(* Loop implementation of convolution run faster with multithreadin *) +def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); +} (**************************) def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, @@ -171,7 +210,22 @@ def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int }; }; } - +(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filter, + int64_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int64_al[N][D][H][W][CO] outArr) +{ + ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); +} (**************************) def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int64_al[inArrS1][inArrS2] inArr, int32_pl dim, int64_al[outArrS1] outArr){ @@ -594,4 +648,4 @@ def void StartComputation() def void EndComputation() { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_porthos.ezpc b/Athos/TFEzPCLibrary/Library64_porthos.ezpc index cc1cf66d87c7718f4d3bd12f851e724692ddc3f9..f092141025b2d94b97ce6ad5cfbf3a8e9f2e1188 100644 --- a/Athos/TFEzPCLibrary/Library64_porthos.ezpc +++ b/Athos/TFEzPCLibrary/Library64_porthos.ezpc @@ -76,7 +76,7 @@ extern void ClearMemSecret1(int32_pl s1, int64_al[s1] arr); extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr); extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int64_al[s1][s2][s3] arr); extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr); -extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr) +extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr); extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr);