diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index d5d09562831aa54bf835e49cf6d60834b9d9198a..a6994cf294fda793b03af58a65de4cbe2da4e8cb 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -116,6 +116,13 @@ class TFNodesAST: TFNodesAST.getOperatorsIdx('+'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) )) + def AddV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef) == 2) + return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx('+'), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) + )) def Mul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -533,6 +540,12 @@ class TFNodesAST: AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), )) + def FusedBatchNormV3(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + return (None, AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), + )) def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # TODO : Do this in somewhat better way @@ -563,6 +576,9 @@ class TFNodesAST: inputsRef = curNode.getInputsRef() return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) + def Softmax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) # def StridedSlice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # inputsRef = curNode.getInputsRef() # return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) @@ -575,4 +591,4 @@ class TFNodesAST: # retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], # TFNodesAST.UninterpFuncCallNames.Pack.name, # list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)) + [AST.Int(axis)] ) - # return (None, retAST) \ No newline at end of file + # return (None, retAST) diff --git a/Athos/TFEzPCLibrary/Library32_cpp.ezpc b/Athos/TFEzPCLibrary/Library32_cpp.ezpc index 4663bac5a601f2642e6ab754427e2fa1f3e63f13..4e5911ae232ee2b08a00d970e43fa2a47c083918 100644 --- a/Athos/TFEzPCLibrary/Library32_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library32_cpp.ezpc @@ -278,7 +278,7 @@ def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, def void ReduceMean24(int32_pl outS1, int32_pl outS2, int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4, int32_al[inS1][inS2][inS3][inS4] inputArr, - int32_pl[2] axes, + int32_al[2] axes, int32_al[outS1][outS2] outputArr ) { @@ -332,4 +332,4 @@ def void StartComputation() def void EndComputation() { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_cpp.ezpc b/Athos/TFEzPCLibrary/Library64_cpp.ezpc index c32a6488452cc54ad8d6e7c378b40911808ac250..3393fa06d5e2123e449f554e8972c07d30af7256 100644 --- a/Athos/TFEzPCLibrary/Library64_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library64_cpp.ezpc @@ -278,7 +278,7 @@ def void FusedBatchNorm4411(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, def void ReduceMean24(int32_pl outS1, int32_pl outS2, int32_pl inS1, int32_pl inS2, int32_pl inS3, int32_pl inS4, int64_al[inS1][inS2][inS3][inS4] inputArr, - int32_pl[2] axes, + int64_al[2] axes, int64_al[outS1][outS2] outputArr ) { @@ -332,4 +332,4 @@ def void StartComputation() def void EndComputation() { return; -} \ No newline at end of file +}