Skip to content
Snippets Groups Projects
Commit a489c49d authored by Bhatu's avatar Bhatu
Browse files

Fixes output mismatch in reduce_mean and reduce_sum.

Closes #97
parent fd7061db
No related branches found
No related tags found
No related merge requests found
......@@ -10,4 +10,3 @@ SeeDot/debug/
*__temp1.ezpc
*__temp2.ezpc
__pycache__/
tests/debug
......@@ -1278,9 +1278,19 @@ class IRBuilderCSF(IRBuilderAST):
outputiters = []
no_elems = 1
j = 0
for i in range(len(inputShape)):
if i not in reduced_dims:
perm.append(i)
# perm will now be [ 1 ,2 ] + [ 0, 3]
perm.extend(reduced_dims)
print(perm)
print(reduced_dims)
loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))]
shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))]
for i in range(len(inputShape)):
if i not in reduced_dims:
calculated_shape.append(inputShape[i])
outputiters.append(inputiters[j])
j = j + 1
......@@ -1289,30 +1299,35 @@ class IRBuilderCSF(IRBuilderAST):
if node.keepdims == 1:
calculated_shape.append(1)
outputiters.append(IR.Int(0,32))
if calculated_shape == []:
calculated_shape = [1]
outputiters.append(IR.Int(0,32))
# perm will now be [ 1 ,2 ] + [ 0, 3]
perm.extend(reduced_dims)
loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))]
outputShape = node.type.shape
assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape)
sumExpr = self.getTempVar()
sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int())
initSumCmd = IR.Assn(sumExpr, IRUtil.zero)
updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters)))
outer_nesting = len(inputShape) - len(reduced_dims)
temp_flat = self.getTempVar()
temp_flat_decl = IR.Decl(temp_flat.idf,
Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
isSecret=node.type.isSecret)
# i1*s2 + i2
flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting])
# temp_flat[i1*s2 + i2] = sum
temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr])
updateOutCmd = IR.Assn(temp_flat_expr, sumExpr)
updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters)))
if node.op == AST.Operators.Mean:
outer_nesting = len(inputShape) - len(reduced_dims)
temp_flat = self.getTempVar()
temp_flat_decl = IR.Decl(temp_flat.idf,
Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
isSecret=node.type.isSecret)
# i1*s2 + i2
flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting])
# temp_flat[i1*s2 + i2] = sum
temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr])
updateOutCmd = IR.Assn(temp_flat_expr, sumExpr)
elif node.op == AST.Operators.ADD:
output = self.getTempVar()
output_decl = IR.Decl(output.idf, node.type)
out_expr = IRUtil.addIndex(output, outputiters)
updateOutCmd = IR.Assn(out_expr, sumExpr)
# Generate the sum loop
inner_loops_processed = 0
......@@ -1323,6 +1338,17 @@ class IRBuilderCSF(IRBuilderAST):
if(inner_loops_processed == len(reduced_dims)):
sum_loop = [initSumCmd] + sum_loop + [updateOutCmd]
if node.op == AST.Operators.ADD:
comment = IR.Comment(str(node.metadata))
final_prog = IRUtil.prog_merge( prog_1,
IR.Prog([comment]),
IR.Prog([sumExpr_decl, output_decl]),
IR.Prog(sum_loop))
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf]
return (final_prog, output)
# Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat)
out_flat = self.getTempVar()
out_flat_decl = IR.Decl(out_flat.idf,
......
results-Porthos2PC-server.csv
debug
......@@ -75,6 +75,9 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype):
([3, 2], [0, 1], False),
([3, 2], 0, False),
([3, 2], 1, False),
([3, 2, 4], 1, False),
([3, 2, 4], [1, 2], False),
([3, 2, 4], [2, 1], False),
([3, 2], 0, True),
],
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment