diff --git a/Athos/.gitignore b/Athos/.gitignore index c36857d56981da383bdb94b4918d549adeb6e985..617a7ab99c9e1d1a7681f3f492bef679234ea83c 100644 --- a/Athos/.gitignore +++ b/Athos/.gitignore @@ -10,4 +10,3 @@ SeeDot/debug/ *__temp1.ezpc *__temp2.ezpc __pycache__/ -tests/debug diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 5561cedee0acd254d20216b8e8e0bad45c3b40bc..8aa3370b1415e94e6da334b558f94febcb93e29b 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1278,9 +1278,19 @@ class IRBuilderCSF(IRBuilderAST): outputiters = [] no_elems = 1 j = 0 + for i in range(len(inputShape)): if i not in reduced_dims: perm.append(i) + # perm will now be [ 1 ,2 ] + [ 0, 3] + perm.extend(reduced_dims) + print(perm) + print(reduced_dims) + loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))] + + for i in range(len(inputShape)): + if i not in reduced_dims: calculated_shape.append(inputShape[i]) outputiters.append(inputiters[j]) j = j + 1 @@ -1289,30 +1299,35 @@ class IRBuilderCSF(IRBuilderAST): if node.keepdims == 1: calculated_shape.append(1) outputiters.append(IR.Int(0,32)) + if calculated_shape == []: calculated_shape = [1] outputiters.append(IR.Int(0,32)) - # perm will now be [ 1 ,2 ] + [ 0, 3] - perm.extend(reduced_dims) - loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + outputShape = node.type.shape assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape) sumExpr = self.getTempVar() sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) initSumCmd = IR.Assn(sumExpr, IRUtil.zero) - updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters))) - - outer_nesting = len(inputShape) - len(reduced_dims) - temp_flat = self.getTempVar() - temp_flat_decl = IR.Decl(temp_flat.idf, - Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), - isSecret=node.type.isSecret) - # i1*s2 + i2 - flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) - # temp_flat[i1*s2 + i2] = sum - temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) - updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters))) + + if node.op == AST.Operators.Mean: + outer_nesting = len(inputShape) - len(reduced_dims) + temp_flat = self.getTempVar() + temp_flat_decl = IR.Decl(temp_flat.idf, + Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), + isSecret=node.type.isSecret) + # i1*s2 + i2 + flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) + # temp_flat[i1*s2 + i2] = sum + temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) + updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + elif node.op == AST.Operators.ADD: + output = self.getTempVar() + output_decl = IR.Decl(output.idf, node.type) + out_expr = IRUtil.addIndex(output, outputiters) + updateOutCmd = IR.Assn(out_expr, sumExpr) # Generate the sum loop inner_loops_processed = 0 @@ -1323,6 +1338,17 @@ class IRBuilderCSF(IRBuilderAST): if(inner_loops_processed == len(reduced_dims)): sum_loop = [initSumCmd] + sum_loop + [updateOutCmd] + if node.op == AST.Operators.ADD: + comment = IR.Comment(str(node.metadata)) + final_prog = IRUtil.prog_merge( prog_1, + IR.Prog([comment]), + IR.Prog([sumExpr_decl, output_decl]), + IR.Prog(sum_loop)) + if not(Util.Config.disableTruncOpti): + self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] + + return (final_prog, output) + # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) out_flat = self.getTempVar() out_flat_decl = IR.Decl(out_flat.idf, diff --git a/Athos/tests/.gitignore b/Athos/tests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..169aaf56ee8d81997e8bb968e8dfb6335c04221b --- /dev/null +++ b/Athos/tests/.gitignore @@ -0,0 +1,2 @@ +results-Porthos2PC-server.csv +debug diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 8eadc0c1f569ea3a13cb60d1f672e08c06bd3bef..87e9b77c8ad11162b8823e6372628061c5b4dc63 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -75,6 +75,9 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype): ([3, 2], [0, 1], False), ([3, 2], 0, False), ([3, 2], 1, False), + ([3, 2, 4], 1, False), + ([3, 2, 4], [1, 2], False), + ([3, 2, 4], [2, 1], False), ([3, 2], 0, True), ], )