diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 170936760d8a7e8a7be2e0d8ec6ede2dc255322c..194532ba17dae244bd0a37a3c8807ffe40b76610 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -538,6 +538,18 @@ class TFNodesAST: AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), )) + def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef) == 2) + permNodeName = inputsRef[1] + # We need to fetch the tensor value of the perm Node + permNode = graph.__getitem__(permNodeName) + permTensor = permNode.getAttrVal("value").getTensor() + permList = permTensor.getContentAsValArr() + assert(permTensor.getDType().kind == "i") + assert(permTensor.getShapeRef().getRank() == 1) + return (None, AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)) + def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # TODO : Do this in somewhat better way inputsRef = curNode.getInputsRef() @@ -583,4 +595,4 @@ class TFNodesAST: # TFNodesAST.UninterpFuncCallNames.Pack.name, # list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] ) # return (None, retAST) - \ No newline at end of file +