From a489c49d5c92f51df0277a7e5751e1b8baeb0bc1 Mon Sep 17 00:00:00 2001 From: Bhatu <prbhatu@microsoft.com> Date: Tue, 12 Jan 2021 16:51:23 +0530 Subject: [PATCH] Fixes output mismatch in reduce_mean and reduce_sum. Closes #97 --- Athos/.gitignore | 1 - Athos/SeeDot/IR/IRBuilderCSF.py | 56 +++++++++++++++++------ Athos/tests/.gitignore | 2 + Athos/tests/tf/unittests/test_unaryops.py | 3 ++ 4 files changed, 46 insertions(+), 16 deletions(-) create mode 100644 Athos/tests/.gitignore diff --git a/Athos/.gitignore b/Athos/.gitignore index c36857d..617a7ab 100644 --- a/Athos/.gitignore +++ b/Athos/.gitignore @@ -10,4 +10,3 @@ SeeDot/debug/ *__temp1.ezpc *__temp2.ezpc __pycache__/ -tests/debug diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 5561ced..8aa3370 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1278,9 +1278,19 @@ class IRBuilderCSF(IRBuilderAST): outputiters = [] no_elems = 1 j = 0 + for i in range(len(inputShape)): if i not in reduced_dims: perm.append(i) + # perm will now be [ 1 ,2 ] + [ 0, 3] + perm.extend(reduced_dims) + print(perm) + print(reduced_dims) + loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))] + + for i in range(len(inputShape)): + if i not in reduced_dims: calculated_shape.append(inputShape[i]) outputiters.append(inputiters[j]) j = j + 1 @@ -1289,30 +1299,35 @@ class IRBuilderCSF(IRBuilderAST): if node.keepdims == 1: calculated_shape.append(1) outputiters.append(IR.Int(0,32)) + if calculated_shape == []: calculated_shape = [1] outputiters.append(IR.Int(0,32)) - # perm will now be [ 1 ,2 ] + [ 0, 3] - perm.extend(reduced_dims) - loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + outputShape = node.type.shape assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape) sumExpr = self.getTempVar() sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) initSumCmd = IR.Assn(sumExpr, IRUtil.zero) - updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters))) - - outer_nesting = len(inputShape) - len(reduced_dims) - temp_flat = self.getTempVar() - temp_flat_decl = IR.Decl(temp_flat.idf, - Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), - isSecret=node.type.isSecret) - # i1*s2 + i2 - flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) - # temp_flat[i1*s2 + i2] = sum - temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) - updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters))) + + if node.op == AST.Operators.Mean: + outer_nesting = len(inputShape) - len(reduced_dims) + temp_flat = self.getTempVar() + temp_flat_decl = IR.Decl(temp_flat.idf, + Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), + isSecret=node.type.isSecret) + # i1*s2 + i2 + flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) + # temp_flat[i1*s2 + i2] = sum + temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) + updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + elif node.op == AST.Operators.ADD: + output = self.getTempVar() + output_decl = IR.Decl(output.idf, node.type) + out_expr = IRUtil.addIndex(output, outputiters) + updateOutCmd = IR.Assn(out_expr, sumExpr) # Generate the sum loop inner_loops_processed = 0 @@ -1323,6 +1338,17 @@ class IRBuilderCSF(IRBuilderAST): if(inner_loops_processed == len(reduced_dims)): sum_loop = [initSumCmd] + sum_loop + [updateOutCmd] + if node.op == AST.Operators.ADD: + comment = IR.Comment(str(node.metadata)) + final_prog = IRUtil.prog_merge( prog_1, + IR.Prog([comment]), + IR.Prog([sumExpr_decl, output_decl]), + IR.Prog(sum_loop)) + if not(Util.Config.disableTruncOpti): + self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] + + return (final_prog, output) + # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) out_flat = self.getTempVar() out_flat_decl = IR.Decl(out_flat.idf, diff --git a/Athos/tests/.gitignore b/Athos/tests/.gitignore new file mode 100644 index 0000000..169aaf5 --- /dev/null +++ b/Athos/tests/.gitignore @@ -0,0 +1,2 @@ +results-Porthos2PC-server.csv +debug diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 8eadc0c..87e9b77 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -75,6 +75,9 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype): ([3, 2], [0, 1], False), ([3, 2], 0, False), ([3, 2], 1, False), + ([3, 2, 4], 1, False), + ([3, 2, 4], [1, 2], False), + ([3, 2, 4], [2, 1], False), ([3, 2], 0, True), ], ) -- GitLab