Skip to content
Snippets Groups Projects
Commit 749ba6b9 authored by Pratik Bhatu's avatar Pratik Bhatu
Browse files

Add back Conv3DBackpropInputV2

parent 614fe2ea
No related branches found
No related tags found
No related merge requests found
......@@ -403,9 +403,64 @@ class TFNodesAST:
options[AST.PaddingKeysDict.strideW] = strideW
options[AST.PaddingKeysDict.ConvDim] = 3
return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
TFNodesAST.getOperatorsIdx('#'),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
options))
TFNodesAST.getOperatorsIdx('#'),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
options))
def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==3) #output_shape, filter, input
stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
assert(stridesUsed[0]==1 and stridesUsed[4]==1)
strideD = stridesUsed[1]
strideH = stridesUsed[2]
strideW = stridesUsed[3]
filterShape = extraNodeInfoDict[inputsRef[1]][0]
FD = filterShape[0]
FH = filterShape[1]
FW = filterShape[2]
inputShape = extraNodeInfoDict[inputsRef[2]][0]
inputD = inputShape[1]
inputH = inputShape[2]
inputW = inputShape[3]
outputShape = extraNodeInfoDict[curNode.getName()][0]
outputD = outputShape[1]
outputH = outputShape[2]
outputW = outputShape[3]
paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
# Important: Using outputH and outputW in the below is not an error!
# For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse.
# Which is why the call to helper_findPadding makes sense.
# The zPads below are of the conv of which this convTranspose is an inverse.
[zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(outputH, outputW, FH, FW, strideH, strideW, paddingUsedStr, imgD = outputD, FD = FD, strideD = strideD)
options = {}
options[AST.PaddingKeysDict.FD] = FD
options[AST.PaddingKeysDict.FH] = FH
options[AST.PaddingKeysDict.FW] = FW
options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft
options[AST.PaddingKeysDict.zPadDRight] = zPadDRight
options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft
options[AST.PaddingKeysDict.zPadHRight] = zPadHRight
options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft
options[AST.PaddingKeysDict.zPadWRight] = zPadWRight
options[AST.PaddingKeysDict.strideD] = strideD
options[AST.PaddingKeysDict.strideH] = strideH
options[AST.PaddingKeysDict.strideW] = strideW
options[AST.PaddingKeysDict.ConvDim] = 3
options[AST.PaddingKeysDict.outputImgD] = outputD
options[AST.PaddingKeysDict.outputImgH] = outputH
options[AST.PaddingKeysDict.outputImgW] = outputW
return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),
TFNodesAST.getOperatorsIdx('#T'),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
options))
def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str):
inputsRef = curNode.getInputsRef()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment