From 3689a4e1d4d2d58334c07dd326555f2ab531ed35 Mon Sep 17 00:00:00 2001
From: Pratik Bhatu <prbhatu@microsoft.com>
Date: Fri, 22 May 2020 16:25:17 +0530
Subject: [PATCH] Add addv2, fusedbatchnormv3, softmax(identity)

---
 Athos/TFCompiler/TFNodesAST.py         | 18 +++++++++++++++++-
 Athos/TFEzPCLibrary/Library32_cpp.ezpc |  4 ++--
 Athos/TFEzPCLibrary/Library64_cpp.ezpc |  4 ++--
 3 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py
index d5d0956..a6994cf 100644
--- a/Athos/TFCompiler/TFNodesAST.py
+++ b/Athos/TFCompiler/TFNodesAST.py
@@ -116,6 +116,13 @@ class TFNodesAST:
 							TFNodesAST.getOperatorsIdx('+'),
 							AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
 							))
+	def AddV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
+		inputsRef = curNode.getInputsRef()
+		assert(len(inputsRef) == 2)
+		return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
+							TFNodesAST.getOperatorsIdx('+'),
+							AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
+							))
 
 	def Mul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 		inputsRef = curNode.getInputsRef()
@@ -533,6 +540,12 @@ class TFNodesAST:
 										 AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
 										 AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),
 										))
+	def FusedBatchNormV3(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
+		inputsRef = curNode.getInputsRef()
+		return (None, AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
+										 AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
+										 AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),
+										))
 
 	def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 		# TODO : Do this in somewhat better way
@@ -563,6 +576,9 @@ class TFNodesAST:
 		inputsRef = curNode.getInputsRef()
 		return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))
 
+	def Softmax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
+		inputsRef = curNode.getInputsRef()
+		return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))
 	# def StridedSlice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 	# 	inputsRef = curNode.getInputsRef()
 	# 	return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))
@@ -575,4 +591,4 @@ class TFNodesAST:
 	# 	retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
 	# 								TFNodesAST.UninterpFuncCallNames.Pack.name, 
 	# 								 list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] )
-	# 	return (None, retAST)
\ No newline at end of file
+	# 	return (None, retAST)
diff --git a/Athos/TFEzPCLibrary/Library32_cpp.ezpc b/Athos/TFEzPCLibrary/Library32_cpp.ezpc
index 4663bac..4e5911a 100644
--- a/Athos/TFEzPCLibrary/Library32_cpp.ezpc
+++ b/Athos/TFEzPCLibrary/Library32_cpp.ezpc
@@ -278,7 +278,7 @@ def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4,
 def void ReduceMean24(int32_pl outS1, int32_pl outS2, 
 					  int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4, 
 					  int32_al[inS1][inS2][inS3][inS4] inputArr,
-					  int32_pl[2] axes,
+					  int32_al[2] axes,
 					  int32_al[outS1][outS2] outputArr
 					  )
 {
@@ -332,4 +332,4 @@ def void StartComputation()
 def void EndComputation()
 {
 	return;
-}
\ No newline at end of file
+}
diff --git a/Athos/TFEzPCLibrary/Library64_cpp.ezpc b/Athos/TFEzPCLibrary/Library64_cpp.ezpc
index c32a648..3393fa0 100644
--- a/Athos/TFEzPCLibrary/Library64_cpp.ezpc
+++ b/Athos/TFEzPCLibrary/Library64_cpp.ezpc
@@ -278,7 +278,7 @@ def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4,
 def void ReduceMean24(int32_pl outS1, int32_pl outS2, 
 					  int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4, 
 					  int64_al[inS1][inS2][inS3][inS4] inputArr,
-					  int32_pl[2] axes,
+					  int64_al[2] axes,
 					  int64_al[outS1][outS2] outputArr
 					  )
 {
@@ -332,4 +332,4 @@ def void StartComputation()
 def void EndComputation()
 {
 	return;
-}
\ No newline at end of file
+}
-- 
GitLab