From 4b9f56ddad94b422c57529150302d00a6290713c Mon Sep 17 00:00:00 2001
From: Pratik Bhatu <prbhatu@microsoft.com>
Date: Wed, 3 Jun 2020 02:44:16 +0530
Subject: [PATCH] Add back transpose op in tf compiler

---
 Athos/TFCompiler/TFNodesAST.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py
index 1709367..194532b 100644
--- a/Athos/TFCompiler/TFNodesAST.py
+++ b/Athos/TFCompiler/TFNodesAST.py
@@ -538,6 +538,18 @@ class TFNodesAST:
 										 AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),
 										))
 
+	def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
+		inputsRef = curNode.getInputsRef()
+		assert(len(inputsRef) == 2)
+		permNodeName = inputsRef[1]
+		# We need to fetch the tensor value of the perm Node
+		permNode = graph.__getitem__(permNodeName)
+		permTensor = permNode.getAttrVal("value").getTensor()
+		permList = permTensor.getContentAsValArr()
+		assert(permTensor.getDType().kind == "i")
+		assert(permTensor.getShapeRef().getRank() == 1)
+		return (None, AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList))
+
 	def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 		# TODO : Do this in somewhat better way
 		inputsRef = curNode.getInputsRef()
@@ -583,4 +595,4 @@ class TFNodesAST:
 	# 								TFNodesAST.UninterpFuncCallNames.Pack.name, 
 	# 								 list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] )
 	# 	return (None, retAST)
-	
\ No newline at end of file
+	
-- 
GitLab