From cee44f6d66150057721eafb7d78148f721b840b1 Mon Sep 17 00:00:00 2001
From: Bhatu <prbhatu@microsoft.com>
Date: Thu, 26 Nov 2020 17:43:49 +0530
Subject: [PATCH] Support for reduce_mean

Adds support for the reduce_mean operation in TensorFlow.

Consider an example with the following inputs:
  Tensor of shape (s0, s1, s2, s3)
  reduction axes = [0, 3]

If keep_dim == true the output is of shape (1, s1, s2, 1),
else the output is of shape (s1, s2). We generate the
following program:

  for i1=[0:s1]
    for i2=[0:s2]
      sum = 0
      for i0=[0:s0]
        for i3=[0:s3]
          sum = sum + input[i0][i1][i2][i3]
      output[i1][i2] = sum / (s0 * s3)        // keep_dim=false
      // OR
      output[0][i1][i2][0] = sum / (s0 * s3)  // keep_dim=true

TODO: Also add support for reduce_sum.
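For reference, the lowering computes the same result as the NumPy
sketch below (illustrative only, not part of the patch; the function
name is hypothetical):

  import numpy as np

  def reduce_mean_reference(inp, axes, keep_dims=False):
      # Sum over the reduction axes, then divide by the product of
      # the reduced extents -- what the generated loop nest computes.
      out = inp.sum(axis=tuple(axes))
      out = out / np.prod([inp.shape[a] for a in axes])
      if keep_dims:
          # Reduced axes are retained with extent 1.
          for a in sorted(axes):
              out = np.expand_dims(out, a)
      return out

  # For inp of shape (s0, s1, s2, s3) and axes = [0, 3]:
  #   keep_dims=False -> shape (s1, s2)
  #   keep_dims=True  -> shape (1, s1, s2, 1)
  # Equivalent to np.mean(inp, axis=tuple(axes), keepdims=keep_dims).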
---
 Athos/SeeDot/AST/AST.py                    |   7 +-
 Athos/SeeDot/AST/ASTVisitor.py             |   2 -
 Athos/SeeDot/AST/MtdAST.py                 |   1 -
 Athos/SeeDot/AST/PrintAST.py               |   2 -
 Athos/SeeDot/IR/IRBuilderCSF.py            | 150 +++++++++++++++++----
 Athos/SeeDot/Optimizations/LivenessOpti.py |   2 +-
 Athos/TFCompiler/TFNodesAST.py             |  26 +++-
 7 files changed, 149 insertions(+), 41 deletions(-)

diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py
index 1ecc85d..ca74704 100644
--- a/Athos/SeeDot/AST/AST.py
+++ b/Athos/SeeDot/AST/AST.py
@@ -356,21 +356,20 @@ class ArgMax(ASTNode):
         self.inShape = inShape
 
 class Reduce(ASTNode):
-    def __init__(self, expr:ID, dim:ID, keepdims:Int, outShape:list, op: Operators):
+    def __init__(self, expr:ID, keepdims:bool, outShape:list, op: Operators, reductionAxesList: list):
         # keepdims is unused for now
         if assertInputTypes:
             assert isinstance(expr, ID)
-            assert isinstance(dim, ID)
-            assert isinstance(keepdims, Int)
+            assert isinstance(keepdims, bool)
             assert isinstance(outShape, list)
             for elem in outShape: assert isinstance(elem, int)
             assert isinstance(op, Operators)
         super().__init__()
         self.expr = expr
-        self.dim = dim
         self.keepdims = keepdims
         self.outShape = outShape
         self.op = op
+        self.reductionAxesList = reductionAxesList
 
 # shape : list of int, dataType : ID
 # NOTE: Though datatype is being passed to this function, the output code eventually only has
diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py
index 1f4d663..03f04ad 100644
--- a/Athos/SeeDot/AST/ASTVisitor.py
+++ b/Athos/SeeDot/AST/ASTVisitor.py
@@ -76,8 +76,6 @@ class ASTVisitor:
     def visitReduce(self, node:AST.Reduce, args=None):
         self.visit(node.expr, args)
-        self.visit(node.dim, args)
-        self.visit(node.keepdims, args)
 
     def visitInput(self, node:AST.Input, args=None):
         pass
 
diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py
index e9d4614..ef9a410 100644
--- a/Athos/SeeDot/AST/MtdAST.py
+++ b/Athos/SeeDot/AST/MtdAST.py
@@ -85,7 +85,6 @@ class MtdAST(ASTVisitor):
     def visitReduce(self, node:AST.Reduce, mtd:dict):
         node.metadata.update(mtd)
         self.visit(node.expr, mtd)
-        self.visit(node.dim, mtd)
 
     def visitInput(self, node:AST.Input, mtd:dict):
         node.metadata.update(mtd)
diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py
index 8402549..1ef915d 100644
--- a/Athos/SeeDot/AST/PrintAST.py
+++ b/Athos/SeeDot/AST/PrintAST.py
@@ -117,8 +117,6 @@ class PrintAST(ASTVisitor):
     def visitReduce(self, node:AST.Reduce, args=None):
         print(indent * node.depth, "reduce", AST.OperatorsSymbolDict[node.op.name], end=' ')
         self.visit(node.expr)
-        self.visit(node.dim)
-        self.visit(node.keepdims)
 
     def visitInput(self, node:AST.Input, args=None):
         print(indent * node.depth, "input( ", node.shape, node.dataType, " <", node.inputByParty.name, "> ", end='')
diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index c315b29..ea9b808 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -1116,37 +1116,139 @@ class IRBuilderCSF(IRBuilderAST):
     def visitReduce(self, node:AST.Reduce, args=None):
         (prog_1, expr1) = self.visit(node.expr)
-        (prog_2, expr2) = self.visit(node.dim)
-
-        returnExpr = self.getTempVar()
-
         assert(node.op in [AST.Operators.ADD, AST.Operators.Mean])
-        if (node.op == AST.Operators.ADD):
-            funcName = "ReduceSum"
-        elif (node.op == AST.Operators.Mean):
-            funcName = "ReduceMean"
-
-        if not(Util.Config.disableTruncOpti):
-            self.scaleFacMapping[returnExpr.idf] = self.scaleFacMapping[expr1.idf]
-
-        funcArgsList = OrderedDict()
-
-        outputShape = node.type.shape
-        for ii, curDim in enumerate(outputShape):
-            funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
+        # We already have the output shape, so we don't need to recompute it from keep_dims.
+        '''
+        We need to reduce across axes.
+        Example: say the reduction axes are [0, 3] and keep_dims = false.
+        Output rank is len(input_shape) - len(reduction_axes), so the output is 2D:
+            for i1=[0:s1]
+                for i2=[0:s2]
+                    sum = 0
+                    for i0=[0:s0]
+                        for i3=[0:s3]
+                            sum = sum + input[i0][i1][i2][i3]
+                    output[i1][i2] = sum / (s0 * s3)
+        If keep_dims == true, the output rank equals the input rank and we generate:
+                    output[0][i1][i2][0] = sum / (s0 * s3)
+
+        Ideally the above loop nest is what we would want to generate. But since we have
+        a division, we need to call the div functionality and flatten the tensors:
+            temp_flat[s1*s2];
+            out_flat[s1*s2];
+            for i1=[0:s1]
+                for i2=[0:s2]
+                    sum = 0
+                    for i0=[0:s0]
+                        for i3=[0:s3]
+                            sum = sum + input[i0][i1][i2][i3]
+                    temp_flat[i1*s2 + i2] = sum
+            ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat)
+            for i1=[0:s1]
+                for i2=[0:s2]
+                    output[i1][i2] = out_flat[i1*s2 + i2]
+        '''
+        reduced_dims = node.reductionAxesList
         inputShape = node.expr.type.shape
-        for ii, curDim in enumerate(inputShape):
-            funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)
+        perm = []
+        calculated_shape = []
+        inputiters = self.getTempIterators(node.expr.type.dim)
+        outputiters = []
+        no_elems = 1
+        j = 0
+        for i in range(len(inputShape)):
+            if i not in reduced_dims:
+                perm.append(i)
+                calculated_shape.append(inputShape[i])
+                outputiters.append(inputiters[j])
+                j = j + 1
+            else:
+                no_elems = no_elems * inputShape[i]
+                if node.keepdims == 1:
+                    calculated_shape.append(1)
+                    outputiters.append(IR.Int(0, 32))
+        # perm will now be [1, 2] + [0, 3]
+        perm.extend(reduced_dims)
+        loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))]
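+        # Illustration for the running example: with input shape (s0,s1,s2,s3) and
+        # reduced_dims = [0, 3], perm = [1, 2, 0, 3] and loop_shape = [s1, s2, s0, s3],
+        # i.e. the kept axes are outermost and the reduced axes are innermost.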
+        outputShape = node.type.shape
+        assert(calculated_shape == outputShape)
+
+        sumExpr = self.getTempVar()
+        sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int())
+        initSumCmd = IR.Assn(sumExpr, IRUtil.zero)
+        updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters)))
+
+        outer_nesting = len(inputShape) - len(reduced_dims)
+        temp_flat = self.getTempVar()
+        temp_flat_decl = IR.Decl(temp_flat.idf,
+                                 Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
+                                 isSecret=node.type.isSecret)
+        # i1*s2 + i2
+        flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting])
+        # temp_flat[i1*s2 + i2] = sum
+        temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr])
+        updateOutCmd = IR.Assn(temp_flat_expr, sumExpr)
+
+        # Generate the sum loop
+        inner_loops_processed = 0
+        sum_loop = [updateSumCmd]
+        for i in reversed(range(len(loop_shape))):
+            sum_loop = [IR.For(inputiters[i], 0, sum_loop, 0, endInt=loop_shape[i])]
+            inner_loops_processed += 1
+            if (inner_loops_processed == len(reduced_dims)):
+                sum_loop = [initSumCmd] + sum_loop + [updateOutCmd]
+
+        # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat)
+        out_flat = self.getTempVar()
+        out_flat_decl = IR.Decl(out_flat.idf,
+                                Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
+                                isSecret=node.type.isSecret)
+        argsDict = OrderedDict()
+        argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size"
+        argsDict[temp_flat] = "input"
+        argsDict[IR.Int(Util.get_volume(loop_shape[outer_nesting:]), 32)] = "divisor"
+        argsDict[out_flat] = "output"
+        div_call = IR.FuncCall("ElemWiseVectorPublicDiv", argsDict)
+
+        # Free temp_flat here (clear temp arrays)
+        argsDict = OrderedDict()
+        argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size"
+        argsDict[temp_flat] = "A"
+        free_temp_flat_call = IR.FuncCall("ClearMemSecret1", argsDict)
+
+        # Unflatten the output
+        output = self.getTempVar()
+        output_decl = IR.Decl(output.idf, node.type)
+        out_expr = IRUtil.addIndex(output, outputiters)
+        out_flat_expr = IRUtil.addIndex(out_flat, [flat_idx_expr])
+        out_assn_expr = IR.Assn(out_expr, out_flat_expr)
+        unflatten_loop = IRUtil.loop(loop_shape[:outer_nesting], inputiters[:outer_nesting], [out_assn_expr])
+
+        # Free out_flat here
+        argsDict = OrderedDict()
+        argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size"
+        argsDict[out_flat] = "A"
+        free_out_flat_call = IR.FuncCall("ClearMemSecret1", argsDict)
+
+        if not(Util.Config.disableTruncOpti):
+            self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf]
 
-        funcArgsList[expr1] = "inputArr"
-        funcArgsList[expr2] = "dimension"
-        funcArgsList[returnExpr] = "outArr"
-        funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)) + self.varNameDelim + str(len(inputShape)), funcArgsList)
         comment = IR.Comment(str(node.metadata))
-        prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))
+        final_prog = IRUtil.prog_merge(prog_1,
+                                       IR.Prog([comment]),
+                                       IR.Prog([sumExpr_decl, temp_flat_decl, out_flat_decl, output_decl]),
+                                       IR.Prog(sum_loop),
+                                       IR.Prog([div_call]),
+                                       IR.Prog([free_temp_flat_call]),
+                                       IR.Prog(unflatten_loop),
+                                       IR.Prog([free_out_flat_call]))
 
-        prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), prog_3)
-        return (prog_3, returnExpr)
+        return (final_prog, output)
 
     def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None):
         (prog1, expr1) = self.visit(node.expr)
@@ -1175,7 +1277,7 @@ class IRBuilderCSF(IRBuilderAST):
             addExpr_sf = self.scaleFacMapping[expr3.idf]
             if (expr_sf > self.scaleFac):
                 #Scale down needed
-                progExtraBefore = self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac)
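+                # Merge into progExtraBefore rather than overwriting it, so any
+                # truncation program emitted earlier is preserved.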
+                progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac))
                 self.scaleFacMapping[expr1.idf] = self.scaleFac
 
             if (multExpr_sf > self.scaleFac):
diff --git a/Athos/SeeDot/Optimizations/LivenessOpti.py b/Athos/SeeDot/Optimizations/LivenessOpti.py
index 161881f..d69131f 100644
--- a/Athos/SeeDot/Optimizations/LivenessOpti.py
+++ b/Athos/SeeDot/Optimizations/LivenessOpti.py
@@ -107,7 +107,7 @@ class LivenessAnalysis(ASTVisitor):
         return unboundVars
 
     def visitReduce(self, node:AST.Reduce, args):
-        unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.dim, args) + self.visit(node.keepdims, args)))
+        unboundVars = list(set(self.visit(node.expr, args)))
         node.optidict[self.optidictKey] = unboundVars
         return unboundVars
diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py
index 530f477..1a8f146 100644
--- a/Athos/TFCompiler/TFNodesAST.py
+++ b/Athos/TFCompiler/TFNodesAST.py
@@ -578,12 +578,18 @@ class TFNodesAST:
         keepdims = False
         if ("keep_dims" in attrMapRef):
             keepdims = attrMapRef["keep_dims"].getB()
+
+        reductionAxesNodeName = inputsRef[1]
+        redAxesN = graph.__getitem__(reductionAxesNodeName)
+        redAxesT = redAxesN.getAttrVal("value").getTensor()
+        reductionAxesList = redAxesT.getContentAsValArr()
+
         curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
         return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
-                                            AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
-                                            AST.Int(int(keepdims), 32, isSecret=False),
+                                            keepdims,
                                             curNodeShapeLi,
-                                            TFNodesAST.getOperatorsIdx('+'))})
+                                            TFNodesAST.getOperatorsIdx('+'),
+                                            reductionAxesList)})
 
     def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
         inputsRef = curNode.getInputsRef()
@@ -592,12 +598,18 @@ class TFNodesAST:
         keepdims = False
         if ("keep_dims" in attrMapRef):
             keepdims = attrMapRef["keep_dims"].getB()
+
+        reductionAxesNodeName = inputsRef[1]
+        redAxesN = graph.__getitem__(reductionAxesNodeName)
+        redAxesT = redAxesN.getAttrVal("value").getTensor()
+        reductionAxesList = redAxesT.getContentAsValArr()
+
         curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
-        return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
-                                            AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
-                                            AST.Int(int(keepdims), 32, isSecret=False),
+        return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
+                                            keepdims,
                                             curNodeShapeLi,
-                                            TFNodesAST.getOperatorsIdx('mean'))})
+                                            TFNodesAST.getOperatorsIdx('mean'),
+                                            reductionAxesList)})
 
     def ArgMax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
         inputsRef = curNode.getInputsRef()
-- 
GitLab