From 973998c27caf90ce20214b2a65aa21196c8fb6ba Mon Sep 17 00:00:00 2001
From: Nishant Kumar <t-niskum@microsoft.com>
Date: Fri, 4 Oct 2019 18:44:17 +0530
Subject: [PATCH] Small bug fixes: 1. Bug fix in SeeDot/IRBuilderCSF.py -- for
 visitFloat() -- had a bug earlier. 2. TFNodesAST -- compilation of mul
 TensorFlow node 3. Added some more functions in library files of
 TFEzPCLibrary.

---
 Athos/SeeDot/IR/IRBuilderCSF.py           |  6 ++--
 Athos/TFCompiler/TFNodesAST.py            |  2 +-
 Athos/TFEzPCLibrary/Library32_common.ezpc | 35 ++++++++++++++++++++++
 Athos/TFEzPCLibrary/Library64_common.ezpc | 36 +++++++++++++++++++++++
 4 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index 31cd90a..bc566e8 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -95,15 +95,15 @@ class IRBuilderCSF(ASTVisitor):
 		r = node.value
 		p = self.get_expnt(abs(r))
 		k = IR.DataType.getInt(np.ldexp(r, p))
-		comment = IR.Comment('Float to int : ' + str(r) + ' to ' + str(k))
 		expr = None
+		prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(r), str(k), node.isSecret))])
 		if not(node.isSecret):
 			expr = IR.Int(k)
 		else:
 			expr = self.getTempVar()
 			self.decls[expr.idf] = [node.type]
-		prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog)
-		return (IR.Prog([comment]), expr)
+			prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog)
+		return (prog, expr)
 
 	def visitId(self, node:AST.ID, args=None):
 		idf = node.name
diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py
index ad3c64f..d5d0956 100644
--- a/Athos/TFCompiler/TFNodesAST.py
+++ b/Athos/TFCompiler/TFNodesAST.py
@@ -121,7 +121,7 @@ class TFNodesAST:
 		inputsRef = curNode.getInputsRef()
 		assert(len(inputsRef) == 2)
 		return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
-							TFNodesAST.getOperatorsIdx('*'),
+							TFNodesAST.getOperatorsIdx('.*'),
 							AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
 							))
 
diff --git a/Athos/TFEzPCLibrary/Library32_common.ezpc b/Athos/TFEzPCLibrary/Library32_common.ezpc
index 935ca25..f6184db 100644
--- a/Athos/TFEzPCLibrary/Library32_common.ezpc
+++ b/Athos/TFEzPCLibrary/Library32_common.ezpc
@@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in
 	};
 }
 
+(**************************)
+def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int32_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2] outArr){
+	for i=[0:s1]{
+		for j=[0:s2]{
+			outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j];
+		};
+	};
+}
+
 (**************************)
 def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){
 	for i1=[0:s1]{
@@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p
 	};
 }
 
+def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int32_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_al[inp2s1][inp2s2] inp2, int32_pl axis, int32_al[s1][s2] outp){
+	for i1=[0:s1]{
+		for i2=[0:s2]{
+			if (axis==0){
+				if (i1 < inp1s1){
+					outp[i1][i2] = inp1[i1][i2];
+				}
+				else{
+					outp[i1][i2] = inp2[i1-inp1s1][i2];
+				};
+			}
+			else{
+				if (i2 < inp1s2){
+					outp[i1][i2] = inp1[i1][i2];
+				}
+				else{
+					outp[i1][i2] = inp2[i1][i2-inp1s2];
+				};
+			};
+		};
+	};
+}
+
 (**************************)
 (* Generic implementation of Conv2DCSF *)
 
@@ -353,3 +385,6 @@ def void ClearMemPublic(int32_pl x){
 	return;
 }
 
+def void ClearMemPublic1(int32_pl s, int32_pl[s] x){
+	return;
+}
\ No newline at end of file
diff --git a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc
index a715aba..042cb9c 100644
--- a/Athos/TFEzPCLibrary/Library64_common.ezpc
+++ b/Athos/TFEzPCLibrary/Library64_common.ezpc
@@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in
 	};
 }
 
+(**************************)
+def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int64_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int64_al[s1][s2] outArr){
+	for i=[0:s1]{
+		for j=[0:s2]{
+			outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j];
+		};
+	};
+}
+
 (**************************)
 def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int64_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int64_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int64_al[s1][s2][s3][s4] outp){
 	for i1=[0:s1]{
@@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p
 	};
 }
 
+def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int64_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int64_al[inp2s1][inp2s2] inp2, int32_pl axis, int64_al[s1][s2] outp){
+	for i1=[0:s1]{
+		for i2=[0:s2]{
+			if (axis==0){
+				if (i1 < inp1s1){
+					outp[i1][i2] = inp1[i1][i2];
+				}
+				else{
+					outp[i1][i2] = inp2[i1-inp1s1][i2];
+				};
+			}
+			else{
+				if (i2 < inp1s2){
+					outp[i1][i2] = inp1[i1][i2];
+				}
+				else{
+					outp[i1][i2] = inp2[i1][i2-inp1s2];
+				};
+			};
+		};
+	};
+}
+
 (**************************)
 (* Generic implementation of Conv2DCSF *)
 
@@ -351,4 +383,8 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32
 (**************************)
 def void ClearMemPublic(int32_pl x){
 	return;
+}
+
+def void ClearMemPublic1(int32_pl s, int32_pl[s] x){
+	return;
 }
\ No newline at end of file
-- 
GitLab