Skip to content
Snippets Groups Projects
Commit 29656c7a authored by Pratik Bhatu's avatar Pratik Bhatu
Browse files

Add back conv3d

parent 4b9f56dd
No related branches found
No related tags found
No related merge requests found
......@@ -280,13 +280,16 @@ class TFNodesAST:
assert(len(inputsRef) == 2)
return (None, AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None))
def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr):
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1
def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None):
if imgD:
assert(FD)
assert(strideD)
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1
if (paddingUsedStr == "\"SAME\""):
# Reference for following:
# https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn
totalPaddingH = totalPaddingW = 0
# https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn
totalPaddingH = totalPaddingW = totalPaddingD = 0
if (imgH % strideH == 0):
totalPaddingH = max(FH - strideH, 0)
else:
......@@ -297,17 +300,30 @@ class TFNodesAST:
else:
totalPaddingW = max(FW - (imgW % strideW), 0)
if imgD:
if (imgD % strideD == 0):
totalPaddingD = max(FD - strideD, 0)
else:
totalPaddingD = max(FD - (imgD % strideD), 0)
zPadHLeft = totalPaddingH // 2
zPadHRight = totalPaddingH - zPadHLeft
zPadWLeft = totalPaddingW // 2
zPadWRight = totalPaddingW - zPadWLeft
zPadWRight = totalPaddingW - zPadWLeft
zPadDLeft = totalPaddingD // 2
zPadDRight = totalPaddingD - zPadDLeft
elif (paddingUsedStr == "\"VALID\""):
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = 0
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0
else:
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = -1
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1
return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight]
if imgD:
return [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight]
else:
return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight]
def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
......@@ -348,6 +364,49 @@ class TFNodesAST:
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
options))
def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==2)
stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
assert(stridesUsed[0]==1 and stridesUsed[4]==1)
strideD = stridesUsed[1]
strideH = stridesUsed[2]
strideW = stridesUsed[3]
inputShape = extraNodeInfoDict[inputsRef[0]][0]
imgD = inputShape[1]
imgH = inputShape[2]
imgW = inputShape[3]
filterShape = extraNodeInfoDict[inputsRef[1]][0]
FD = filterShape[0]
FH = filterShape[1]
FW = filterShape[2]
paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
[zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD )
options = {}
options[AST.PaddingKeysDict.FD] = FD
options[AST.PaddingKeysDict.FH] = FH
options[AST.PaddingKeysDict.FW] = FW
options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft
options[AST.PaddingKeysDict.zPadDRight] = zPadDRight
options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft
options[AST.PaddingKeysDict.zPadHRight] = zPadHRight
options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft
options[AST.PaddingKeysDict.zPadWRight] = zPadWRight
options[AST.PaddingKeysDict.strideD] = strideD
options[AST.PaddingKeysDict.strideH] = strideH
options[AST.PaddingKeysDict.strideW] = strideW
options[AST.PaddingKeysDict.ConvDim] = 3
return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
TFNodesAST.getOperatorsIdx('#'),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
options))
def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment