From 29656c7a3b7b8865c1168ba3647b12568702978f Mon Sep 17 00:00:00 2001 From: Pratik Bhatu <prbhatu@microsoft.com> Date: Wed, 3 Jun 2020 03:22:54 +0530 Subject: [PATCH] Add back conv3d --- Athos/TFCompiler/TFNodesAST.py | 77 ++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 9 deletions(-) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 194532b..8e05018 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -280,13 +280,16 @@ class TFNodesAST: assert(len(inputsRef) == 2) return (None, AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None)) - def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr): - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1 + def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None): + if imgD: + assert(FD) + assert(strideD) + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 if (paddingUsedStr == "\"SAME\""): # Reference for following: - # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn - totalPaddingH = totalPaddingW = 0 - + # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn + totalPaddingH = totalPaddingW = totalPaddingD = 0 + if (imgH % strideH == 0): totalPaddingH = max(FH - strideH, 0) else: @@ -297,17 +300,30 @@ class TFNodesAST: else: totalPaddingW = max(FW - (imgW % strideW), 0) + if imgD: + if (imgD % strideD == 0): + totalPaddingD = max(FD - strideD, 0) + else: + totalPaddingD = max(FD - (imgD % strideD), 0) + zPadHLeft = totalPaddingH // 2 zPadHRight = totalPaddingH - zPadHLeft zPadWLeft = totalPaddingW // 2 - zPadWRight = totalPaddingW - zPadWLeft + zPadWRight = totalPaddingW - zPadWLeft + + zPadDLeft = totalPaddingD // 2 + zPadDRight = totalPaddingD - zPadDLeft + elif (paddingUsedStr == "\"VALID\""): - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = 0 + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0 else: - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1 + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 - return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] + if imgD: + return [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] + else: + return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -348,6 +364,49 @@ class TFNodesAST: AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), options)) + def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==2) + + stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + assert(stridesUsed[0]==1 and stridesUsed[4]==1) + strideD = stridesUsed[1] + strideH = stridesUsed[2] + strideW = stridesUsed[3] + + inputShape = extraNodeInfoDict[inputsRef[0]][0] + imgD = inputShape[1] + imgH = inputShape[2] + imgW = inputShape[3] + + filterShape = extraNodeInfoDict[inputsRef[1]][0] + FD = filterShape[0] + FH = filterShape[1] + FW = filterShape[2] + + paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + + [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD ) + + options = {} + options[AST.PaddingKeysDict.FD] = FD + options[AST.PaddingKeysDict.FH] = FH + options[AST.PaddingKeysDict.FW] = FW + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = strideD + options[AST.PaddingKeysDict.strideH] = strideH + options[AST.PaddingKeysDict.strideW] = strideW + options[AST.PaddingKeysDict.ConvDim] = 3 + return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx('#'), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + options)) + def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) -- GitLab