From 4b9f56ddad94b422c57529150302d00a6290713c Mon Sep 17 00:00:00 2001 From: Pratik Bhatu <prbhatu@microsoft.com> Date: Wed, 3 Jun 2020 02:44:16 +0530 Subject: [PATCH] Add back transpose op in tf compiler --- Athos/TFCompiler/TFNodesAST.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 1709367..194532b 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -538,6 +538,18 @@ class TFNodesAST: AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), )) + def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef) == 2) + permNodeName = inputsRef[1] + # We need to fetch the tensor value of the perm Node + permNode = graph.__getitem__(permNodeName) + permTensor = permNode.getAttrVal("value").getTensor() + permList = permTensor.getContentAsValArr() + assert(permTensor.getDType().kind == "i") + assert(permTensor.getShapeRef().getRank() == 1) + return (None, AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)) + def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # TODO : Do this in somewhat better way inputsRef = curNode.getInputsRef() @@ -583,4 +595,4 @@ class TFNodesAST: # TFNodesAST.UninterpFuncCallNames.Pack.name, # list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] ) # return (None, retAST) - \ No newline at end of file + -- GitLab