From a489c49d5c92f51df0277a7e5751e1b8baeb0bc1 Mon Sep 17 00:00:00 2001
From: Bhatu <prbhatu@microsoft.com>
Date: Tue, 12 Jan 2021 16:51:23 +0530
Subject: [PATCH] Fixes output mismatch in reduce_mean and reduce_sum.

Closes #97
---
 Athos/.gitignore                          |  1 -
 Athos/SeeDot/IR/IRBuilderCSF.py           | 56 +++++++++++++++++------
 Athos/tests/.gitignore                    |  2 +
 Athos/tests/tf/unittests/test_unaryops.py |  3 ++
 4 files changed, 46 insertions(+), 16 deletions(-)
 create mode 100644 Athos/tests/.gitignore

diff --git a/Athos/.gitignore b/Athos/.gitignore
index c36857d..617a7ab 100644
--- a/Athos/.gitignore
+++ b/Athos/.gitignore
@@ -10,4 +10,3 @@ SeeDot/debug/
 *__temp1.ezpc
 *__temp2.ezpc
 __pycache__/
-tests/debug
diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index 5561ced..8aa3370 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -1278,9 +1278,19 @@ class IRBuilderCSF(IRBuilderAST):
 		outputiters = []
 		no_elems = 1
 		j = 0
+
 		for i in range(len(inputShape)):
 			if i not in reduced_dims:
 				perm.append(i)
+		# perm will now be [ 1 ,2 ] + [ 0, 3]
+		perm.extend(reduced_dims)
+		print(perm)
+		print(reduced_dims)
+		loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))]
+		shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))]
+
+		for i in range(len(inputShape)):
+			if i not in reduced_dims:
 				calculated_shape.append(inputShape[i])
 				outputiters.append(inputiters[j])
 				j = j + 1
@@ -1289,30 +1299,35 @@ class IRBuilderCSF(IRBuilderAST):
 				if node.keepdims == 1:
 					calculated_shape.append(1)
 					outputiters.append(IR.Int(0,32))
+
 		if calculated_shape == []:
 			calculated_shape = [1]
 			outputiters.append(IR.Int(0,32))
-		# perm will now be [ 1 ,2 ] + [ 0, 3]
-		perm.extend(reduced_dims)
-		loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))]
+
 		outputShape = node.type.shape
 		assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape)
 
 		sumExpr = self.getTempVar()
 		sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int())
 		initSumCmd = IR.Assn(sumExpr, IRUtil.zero)
-		updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters)))
-
-		outer_nesting = len(inputShape) - len(reduced_dims)
-		temp_flat = self.getTempVar()
-		temp_flat_decl = IR.Decl(temp_flat.idf,
-								Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
-								isSecret=node.type.isSecret)
-		# i1*s2 + i2
-		flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting])
-		# temp_flat[i1*s2 + i2] = sum
-		temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr])
-		updateOutCmd = IR.Assn(temp_flat_expr, sumExpr)
+		updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters)))
+
+		if node.op == AST.Operators.Mean:
+			outer_nesting = len(inputShape) - len(reduced_dims)
+			temp_flat = self.getTempVar()
+			temp_flat_decl = IR.Decl(temp_flat.idf,
+									Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
+									isSecret=node.type.isSecret)
+			# i1*s2 + i2
+			flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting])
+			# temp_flat[i1*s2 + i2] = sum
+			temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr])
+			updateOutCmd = IR.Assn(temp_flat_expr, sumExpr)
+		elif node.op == AST.Operators.ADD:
+			output = self.getTempVar()
+			output_decl =  IR.Decl(output.idf, node.type)
+			out_expr = IRUtil.addIndex(output, outputiters)
+			updateOutCmd = IR.Assn(out_expr, sumExpr)
 
 		# Generate the sum loop
 		inner_loops_processed = 0
@@ -1323,6 +1338,17 @@ class IRBuilderCSF(IRBuilderAST):
 			if(inner_loops_processed == len(reduced_dims)):
 				sum_loop = [initSumCmd] + sum_loop + [updateOutCmd]
 
+		if node.op == AST.Operators.ADD:
+			comment = IR.Comment(str(node.metadata))
+			final_prog = IRUtil.prog_merge(	prog_1,
+										IR.Prog([comment]),
+										IR.Prog([sumExpr_decl, output_decl]),
+										IR.Prog(sum_loop))
+			if not(Util.Config.disableTruncOpti):
+				self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf]
+
+			return (final_prog, output)
+
 		# Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat)
 		out_flat = self.getTempVar()
 		out_flat_decl = IR.Decl(out_flat.idf,
diff --git a/Athos/tests/.gitignore b/Athos/tests/.gitignore
new file mode 100644
index 0000000..169aaf5
--- /dev/null
+++ b/Athos/tests/.gitignore
@@ -0,0 +1,2 @@
+results-Porthos2PC-server.csv
+debug
diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py
index 8eadc0c..87e9b77 100644
--- a/Athos/tests/tf/unittests/test_unaryops.py
+++ b/Athos/tests/tf/unittests/test_unaryops.py
@@ -75,6 +75,9 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype):
         ([3, 2], [0, 1], False),
         ([3, 2], 0, False),
         ([3, 2], 1, False),
+        ([3, 2, 4], 1, False),
+        ([3, 2, 4], [1, 2], False),
+        ([3, 2, 4], [2, 1], False),
         ([3, 2], 0, True),
     ],
 )
-- 
GitLab