From cee44f6d66150057721eafb7d78148f721b840b1 Mon Sep 17 00:00:00 2001
From: Bhatu <prbhatu@microsoft.com>
Date: Thu, 26 Nov 2020 17:43:49 +0530
Subject: [PATCH] Support for reduce_mean

Adds support for the reduce_mean operation in TensorFlow.
Consider the example:
  For inputs:
    Tensor of shape (s0, s1, s2, s3)
    reduction axes = [0, 3]

  We generate the following program:
    If keep_dims == true
      output is of shape (1, s1, s2, 1)
    else
      output is of shape (s1, s2)

    for i1=[0:s1]
      for i2=[0:s2]
        sum = 0
        for i0=[0:s0]
          for i3=[0:s3]
            sum = sum + input[i0][i1][i2][i3]
        output[i1][i2] = sum / (s0 * s3)        // keep_dims=false
  OR
        output[0][i1][i2][0] = sum / (s0 * s3)  // keep_dims=true

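As a sanity check, the shape rule above matches NumPy's reductions; a short
illustrative snippet (not part of the patch; the concrete sizes are arbitrary):

  import numpy as np
  x = np.random.rand(2, 5, 6, 3)           # shape (s0, s1, s2, s3)
  y = x.mean(axis=(0, 3), keepdims=False)  # y.shape == (5, 6)
  z = x.mean(axis=(0, 3), keepdims=True)   # z.shape == (1, 5, 6, 1)
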
Also fixes visitFusedBatchNorm so that the scale-down truncation call is
merged into progExtraBefore instead of overwriting it.

TODO: Also add support for reduce_sum; the lowering below always divides by
the number of reduced elements, so it currently implements only mean.
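
For reduce_sum the same lowering should apply minus the final division
(continuing the sketch above; hypothetical until implemented):

  s = x.sum(axis=(0, 3))   # reduce_sum: the plain sum, no division
  m = x.mean(axis=(0, 3))  # reduce_mean: s / (s0 * s3)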
---
 Athos/SeeDot/AST/AST.py                    |   9 +-
 Athos/SeeDot/AST/ASTVisitor.py             |   2 -
 Athos/SeeDot/AST/MtdAST.py                 |   1 -
 Athos/SeeDot/AST/PrintAST.py               |   2 -
 Athos/SeeDot/IR/IRBuilderCSF.py            | 156 +++++++++++++++++----
 Athos/SeeDot/Optimizations/LivenessOpti.py |   2 +-
 Athos/TFCompiler/TFNodesAST.py             |  28 +++-
 7 files changed, 158 insertions(+), 42 deletions(-)

diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py
index 1ecc85d..ca74704 100644
--- a/Athos/SeeDot/AST/AST.py
+++ b/Athos/SeeDot/AST/AST.py
@@ -356,21 +356,20 @@ class ArgMax(ASTNode):
 		self.inShape = inShape
 
 class Reduce(ASTNode):
-	def __init__(self, expr:ID, dim:ID, keepdims:Int, outShape:list, op: Operators):
+	def __init__(self, expr:ID, keepdims:bool, outShape:list, op: Operators, reductionAxesList: list):
-		# keepdims is unused for now
+		# keepdims == True retains each reduced axis with size 1 in the output
 		if assertInputTypes:
 			assert isinstance(expr, ID)
-			assert isinstance(dim, ID)
-			assert isinstance(keepdims, Int)
+			assert isinstance(keepdims, bool)
 			assert isinstance(outShape, list)
 			for elem in outShape: assert isinstance(elem, int)
 			assert isinstance(op, Operators)
 		super().__init__()
 		self.expr = expr
-		self.dim = dim
 		self.keepdims = keepdims
 		self.outShape = outShape
 		self.op = op
+		self.reductionAxesList = reductionAxesList
 
 # shape : list of int, dataType : ID
 # NOTE: Though datatype is being passed to this function, the output code eventually only has 
diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py
index 1f4d663..03f04ad 100644
--- a/Athos/SeeDot/AST/ASTVisitor.py
+++ b/Athos/SeeDot/AST/ASTVisitor.py
@@ -76,8 +76,6 @@ class ASTVisitor:
 
 	def visitReduce(self, node:AST.Reduce, args=None):
 		self.visit(node.expr, args)
-		self.visit(node.dim, args)
-		self.visit(node.keepdims, args)
 
 	def visitInput(self, node:AST.Input, args=None):
 		pass
diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py
index e9d4614..ef9a410 100644
--- a/Athos/SeeDot/AST/MtdAST.py
+++ b/Athos/SeeDot/AST/MtdAST.py
@@ -85,7 +85,6 @@ class MtdAST(ASTVisitor):
 	def visitReduce(self, node:AST.Reduce, mtd:dict):
 		node.metadata.update(mtd)
 		self.visit(node.expr, mtd)
-		self.visit(node.dim, mtd)
 
 	def visitInput(self, node:AST.Input, mtd:dict):
 		node.metadata.update(mtd)
diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py
index 8402549..1ef915d 100644
--- a/Athos/SeeDot/AST/PrintAST.py
+++ b/Athos/SeeDot/AST/PrintAST.py
@@ -117,8 +117,6 @@ class PrintAST(ASTVisitor):
 	def visitReduce(self, node:AST.Reduce, args=None):
 		print(indent * node.depth, "reduce", AST.OperatorsSymbolDict[node.op.name], end=' ')
 		self.visit(node.expr)
-		self.visit(node.dim)
-		self.visit(node.keepdims)
 
 	def visitInput(self, node:AST.Input, args=None):
 		print(indent * node.depth, "input( ", node.shape, node.dataType, " <", node.inputByParty.name, "> ", end='')
diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index c315b29..ea9b808 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -1116,37 +1116,145 @@ class IRBuilderCSF(IRBuilderAST):
 
 	def visitReduce(self, node:AST.Reduce, args=None):
 		(prog_1, expr1) = self.visit(node.expr)
-		(prog_2, expr2) = self.visit(node.dim)
-
-		returnExpr = self.getTempVar()
-
 		assert(node.op in [AST.Operators.ADD, AST.Operators.Mean])
-		if (node.op == AST.Operators.ADD):
-			funcName = "ReduceSum"
-		elif (node.op == AST.Operators.Mean):
-			funcName = "ReduceMean"
 
-		if not(Util.Config.disableTruncOpti):
-			self.scaleFacMapping[returnExpr.idf] = self.scaleFacMapping[expr1.idf]
+		# The output shape is already available in node.type.shape, so we need not recompute it from keep_dims
 
-		funcArgsList = OrderedDict()
-		outputShape = node.type.shape
-		for ii, curDim in enumerate(outputShape):
-			funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
+		'''
+			We need to reduce across axes.
+			Example: say the reduction axes are [0, 3] and keep_dims = false.
+			output rank = len(input_shape) - len(reduction_axes),
+			so here the output is 2D. We want to generate:
+			for i1=[0:s1]
+				for i2=[0:s2]
+					sum = 0
+					for i0=[0:s0]
+						for i3=[0:s3]
+							sum = sum + input[i0][i1][i2][i3]
+					output[i1][i2] = sum / (s0 * s3)
+			If keep_dims == true, the output rank equals the input rank; we generate:
+					output[0][i1][i2][0] = sum / (s0 * s3)
+
+			Ideally we would emit the above loop nest directly. But the division must go
+			through the div functionality, which operates on flat arrays, so we flatten:
+			temp_flat[s1*s2];
+			out_flat[s1*s2];
+			for i1=[0:s1]
+				for i2=[0:s2]
+					sum = 0
+					for i0=[0:s0]
+						for i3=[0:s3]
+							sum = sum + input[i0][i1][i2][i3]
+					temp_flat[i1*s2 + i2] = sum
+			ElemWiseVectorPublicDiv(size=s1*s2, input=temp_flat, divisor=s0*s3, output=out_flat)
+			for i1=[0:s1]
+				for i2=[0:s2]
+					output[i1][i2] = out_flat[i1*s2 + i2]
 
+		'''
+		reduced_dims = node.reductionAxesList
 		inputShape = node.expr.type.shape
-		for ii, curDim in enumerate(inputShape):
-			funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)
+		perm = []
+		calculated_shape = []
+		inputiters = self.getTempIterators(node.expr.type.dim)
+		outputiters = []
+		no_elems = 1
+		j = 0
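+		# Walk the input dims in order: kept dims get a loop iterator and a slot in the
+		# output shape; reduced dims only grow no_elems (plus a fixed 0 index if keep_dims).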
+		for i in range(len(inputShape)):
+			if i not in reduced_dims:
+				perm.append(i)
+				calculated_shape.append(inputShape[i])
+				outputiters.append(inputiters[j])
+				j = j + 1
+			else:
+				no_elems = no_elems * inputShape[i]
+				if node.keepdims:
+					calculated_shape.append(1)
+					outputiters.append(IR.Int(0,32))
+		# For the example above, perm is now [1, 2] + [0, 3]
+		perm.extend(reduced_dims)
+		loop_shape = [inputShape[p] for p in perm]
+		outputShape = node.type.shape
+		assert(calculated_shape == outputShape)
+
+		sumExpr = self.getTempVar()
+		sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int())
+		initSumCmd = IR.Assn(sumExpr, IRUtil.zero)
+		# Index the input in its own dim order: dim perm[k] is scanned by inputiters[k].
+		idx_in_input_order = [inputiters[perm.index(d)] for d in range(len(inputShape))]
+		updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, idx_in_input_order)))
+
+		outer_nesting = len(inputShape) - len(reduced_dims)
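+		# The first outer_nesting loops (in perm order) enumerate exactly the output
+		# elements; the remaining len(reduced_dims) loops cover the dims being summed.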
+		temp_flat = self.getTempVar()
+		temp_flat_decl = IR.Decl(temp_flat.idf,
+								Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
+								isSecret=node.type.isSecret)
+		# i1*s2 + i2
+		flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting])
+		# temp_flat[i1*s2 + i2] = sum
+		temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr])
+		updateOutCmd = IR.Assn(temp_flat_expr, sumExpr)
+
+		# Build the loop nest inside-out; once the reduction loops are wrapped, attach the sum init and the flat store around them
+		inner_loops_processed = 0
+		sum_loop = [updateSumCmd]
+		for i in reversed(range(len(loop_shape))):
+			sum_loop = [IR.For(inputiters[i], 0, sum_loop, 0, endInt=loop_shape[i])]
+			inner_loops_processed += 1
+			if (inner_loops_processed == len(reduced_dims)):
+				sum_loop = [initSumCmd] + sum_loop + [updateOutCmd]
+
+		# Insert call to ElemWiseVectorPublicDiv(size=s1*s2, input=temp_flat, divisor=s0*s3, output=out_flat)
+		out_flat = self.getTempVar()
+		out_flat_decl = IR.Decl(out_flat.idf,
+								Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint),
+								isSecret=node.type.isSecret)
+		argsDict = OrderedDict()
+		argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size"
+		argsDict[temp_flat] = "input"
+		argsDict[IR.Int(no_elems, 32)] = "divisor"  # no_elems = product of the reduced dims (s0*s3)
+		argsDict[out_flat] = "output"
+		div_call = IR.FuncCall("ElemWiseVectorPublicDiv", argsDict)
+
+		# temp_flat is dead after the division;
+		# free it with ClearMemSecret1.
+		argsDict = OrderedDict()
+		argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size"
+		argsDict[temp_flat] = "A"
+		free_temp_flat_call = IR.FuncCall("ClearMemSecret1", argsDict)
+
+		# Unflatten the output
+		output = self.getTempVar()
+		output_decl = IR.Decl(output.idf, node.type)
+		out_expr = IRUtil.addIndex(output, outputiters)
+		out_flat_expr = IRUtil.addIndex(out_flat, [flat_idx_expr])
+		out_assn_expr = IR.Assn(out_expr, out_flat_expr)
+		unflatten_loop = IRUtil.loop(loop_shape[:outer_nesting], inputiters[:outer_nesting], [out_assn_expr])
+
+		# Free out_flat here
+		argsDict = OrderedDict()
+		argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size"
+		argsDict[out_flat] = "A"
+		free_out_flat_call = IR.FuncCall("ClearMemSecret1", argsDict)
+
+		if not(Util.Config.disableTruncOpti):
+			self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf]
 
-		funcArgsList[expr1] = "inputArr"
-		funcArgsList[expr2] = "dimension"
-		funcArgsList[returnExpr] = "outArr"
-		funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)) + self.varNameDelim + str(len(inputShape)), funcArgsList)
 		comment = IR.Comment(str(node.metadata))
-		prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))
+		final_prog = IRUtil.prog_merge(	prog_1,
+										IR.Prog([comment]),
+										IR.Prog([sumExpr_decl, temp_flat_decl, out_flat_decl, output_decl]),
+										IR.Prog(sum_loop),
+										IR.Prog([div_call]),
+										IR.Prog([free_temp_flat_call]),
+										IR.Prog(unflatten_loop),
+										IR.Prog([free_out_flat_call]))
 
-		prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), prog_3)
-		return (prog_3, returnExpr)
+		return (final_prog, output)
 
 	def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None):
 		(prog1, expr1) = self.visit(node.expr)
@@ -1175,7 +1283,7 @@ class IRBuilderCSF(IRBuilderAST):
 			addExpr_sf = self.scaleFacMapping[expr3.idf]
 			if (expr_sf > self.scaleFac):
 				#Scale down needed
-				progExtraBefore = self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac)
+				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac))
 				self.scaleFacMapping[expr1.idf] = self.scaleFac
 
 			if (multExpr_sf > self.scaleFac):
diff --git a/Athos/SeeDot/Optimizations/LivenessOpti.py b/Athos/SeeDot/Optimizations/LivenessOpti.py
index 161881f..d69131f 100644
--- a/Athos/SeeDot/Optimizations/LivenessOpti.py
+++ b/Athos/SeeDot/Optimizations/LivenessOpti.py
@@ -107,7 +107,7 @@ class LivenessAnalysis(ASTVisitor):
 		return unboundVars
 
 	def visitReduce(self, node:AST.Reduce, args):
-		unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.dim, args) + self.visit(node.keepdims, args)))
+		unboundVars = list(set(self.visit(node.expr, args)))
 		node.optidict[self.optidictKey] = unboundVars
 		return unboundVars
 
diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py
index 530f477..1a8f146 100644
--- a/Athos/TFCompiler/TFNodesAST.py
+++ b/Athos/TFCompiler/TFNodesAST.py
@@ -578,12 +578,19 @@ class TFNodesAST:
 		keepdims = False
 		if ("keep_dims" in attrMapRef):
 			keepdims = attrMapRef["keep_dims"].getB()
+
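+		# The reduction axes arrive as the op's second input, a constant tensor.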
+		reductionAxesNodeName = inputsRef[1]
+		redAxesN = graph.__getitem__(reductionAxesNodeName)
+		redAxesT = redAxesN.getAttrVal("value").getTensor()
+		reductionAxesList = redAxesT.getContentAsValArr()
+
 		curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
 		return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), 
-								 AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
-								 AST.Int(int(keepdims), 32, isSecret=False),
+								 keepdims,
 								 curNodeShapeLi,
-								 TFNodesAST.getOperatorsIdx('+'))})
+								 TFNodesAST.getOperatorsIdx('+'),
+								 reductionAxesList)})
 
 	def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 		inputsRef = curNode.getInputsRef()
@@ -592,12 +599,19 @@ class TFNodesAST:
 		keepdims = False
 		if ("keep_dims" in attrMapRef):
 			keepdims = attrMapRef["keep_dims"].getB()
+
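+		# The reduction axes arrive as the op's second input, a constant tensor.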
+		reductionAxesNodeName = inputsRef[1]
+		redAxesN = graph.__getitem__(reductionAxesNodeName)
+		redAxesT = redAxesN.getAttrVal("value").getTensor()
+		reductionAxesList = redAxesT.getContentAsValArr()
+
 		curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
-		return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), 
-								AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), 
-								AST.Int(int(keepdims), 32, isSecret=False),
+		return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
+								keepdims,
 								curNodeShapeLi,
-								TFNodesAST.getOperatorsIdx('mean'))})
+								TFNodesAST.getOperatorsIdx('mean'),
+								reductionAxesList)})
 
 	def ArgMax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
 		inputsRef = curNode.getInputsRef()
-- 
GitLab