diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index ea9b80833e84ab7f0cf6a0d6542f0cb86d9d08ca..6f889ae1b6327fcacf13f9d0776dca748da13f83 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -425,19 +425,16 @@ class IRBuilderCSF(IRBuilderAST):
 
 		op = node.op
 		if (op == AST.Operators.ADD):
-			(op_ir, op_fn) = (IR.Op.Op['+'], operator.add)
-			funcName = "MatAdd"
+			op_ir = IR.Op.Op['+']
 		elif (op == AST.Operators.SUB):
-			(op_ir, op_fn) = (IR.Op.Op['-'], operator.sub)
-			funcName = "MatSub"
+			op_ir = IR.Op.Op['-']
 		elif (op == AST.Operators.Equal):
-			(op_ir, op_fn) = (IR.Op.Op['=='], operator.eq)
-			funcName = "MatEqual"
+			op_ir = IR.Op.Op['==']
 		else:
 			assert False
 
-		typ_3 = node.type
-		expr_3 = self.getTempVar()
+		node_type = node.type
+		out_arr = self.getTempVar()
 		cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
 		comment = IR.Comment(str(node.metadata))
 
@@ -468,43 +465,54 @@ class IRBuilderCSF(IRBuilderAST):
 				argsDict[exprToScale] = "exprToScale, arg#{0}".format(2 if (expr1_sf>expr2_sf) else 1)
 				argsDict[IR.Int(scaleUpFactor, 32)] = "ScaleUpFactor"
 				funcCall = IR.FuncCall(curFuncName, argsDict)
-				curProg = IR.Prog([comm,funcCall])
+
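+				# For int/scalar operands the scale-up call returns the scaled value, so it is
+				# assigned back to the expression; tensor operands are scaled in place by the call.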
+				if Type.isInt(typeOfExprToScale) or typeOfExprToScale.shape == []:
+					assn_expr = IR.Assn(exprToScale, funcCall)
+					curProg = IR.Prog([comm,assn_expr])
+				else:
+					curProg = IR.Prog([comm,funcCall])
 				prog_1 = IRUtil.prog_merge(curProg, prog_1)
 
-			self.scaleFacMapping[expr_3.idf] = self.scaleFacMapping[expr_1.idf]
+			self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[expr_1.idf]
 
-		if Type.isInt(typ_3):
-			decl = IR.Decl(expr_3.idf, typ_3, typ_3.bitlen, typ_3.isSecret)
-			assign = IR.Assn(expr_3, IR.IntBop(expr_1, op_ir, expr_2))
-			prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, decl, assign]))
-		else:
-			## TODO
-			if (node.type.dim != node.expr1.type.dim):
-				# This needs broadcast of expr1
-				assert False # For now this shouldn't occur
-			if (node.type.dim != node.expr2.type.dim):
-				# This needs broadcast of expr2
-				funcName += 'BroadCast'
-
-			outputShape = typ_3.shape
-			argsDict = OrderedDict()
-			inp1_shape = node.expr1.type.shape
-			inp2_shape = node.expr2.type.shape
-			for ii,curDimSize in enumerate(inp1_shape):
-				argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-			for ii,curDimSize in enumerate(inp2_shape):
-				argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-			for ii,curDimSize in enumerate(outputShape):
-				argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-			argsDict[expr_1] = "A"
-			argsDict[expr_2] = "B"
-			argsDict[expr_3] = "C"
-			funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
-			prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, funcCall]))
-			prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)
-
-		return (prog_3, expr_3)
+		decl = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret)
 
+		if Type.isInt(node_type):
+			assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2))
+			out_prog = IR.Prog([assign])
+		else:
+			outputShape = node_type.shape
+			inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape
+			inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape
+
+			expected_output_shape, _, _ = Util.getBroadcastShapes(inp1_shape, inp2_shape)
+			assert(outputShape == expected_output_shape)
+			out_prog = IRUtil.generateBroadcastLoopBOp(expr_1, inp1_shape, expr_2, inp2_shape, out_arr, op_ir)
+
+		out_prog = IRUtil.prog_merge(IR.Prog([comment, cmd0, decl]), out_prog)
+		out_prog = IRUtil.prog_merge(prog_1, prog_2, out_prog)
+		return (out_prog, out_arr)
+
+
+	# We first flatten both inputs into 1-d arrays.
+	# For simplicity, consider a non-broadcast example:
+	# inputs : inp1_arr[s1][s2], inp2_arr[s1][s2]
+	# after flattening : inp1_arr_flat[s1*s2], inp2_arr_flat[s1*s2]
+	# for i1=[0:s1]
+	#   for i2=[0:s2]
+	#     idx = i1*s2 + i2
+	#     inp1_arr_flat[idx] = inp1_arr[i1][i2]
+	#     inp2_arr_flat[idx] = inp2_arr[i1][i2]
+	# If one input is a model weight (the server's input) and the other is secret-shared, we can call an optimized version of mul
+	#   ElemWiseActModelVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) <- optimized
+	# OR
+	#  ElemWiseSecretSharedVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat)
+	# Finally we reshape the flattened output
+	# for i1=[0:s1]
+	#   for i2=[0:s2]
+	#     idx = i1*s2 + i2
+	#     out_arr[i1][i2] = out_arr_flat[idx]
+	# Standard broadcast rules apply to generate these flattened tensors.
 	def visitBopElemWiseOp(self, node:AST.BOp, args=None):
 		(prog_1, expr_1) = self.visit(node.expr1)
 		(prog_2, expr_2) = self.visit(node.expr2)
@@ -515,38 +523,148 @@ class IRBuilderCSF(IRBuilderAST):
 		elif (node.op == AST.Operators.ElemWiseDiv):
 			op_ir = IR.Op.Op['./']
 			funcName = "ElemWiseDiv"
+			assert False, "Did not implement div yet"
+		else:
+			assert False, "Non mul/div elemwise op"
 
-		typ_3 = node.type
-		expr_3 = self.getTempVar()
+		comment = IR.Comment(str(node.metadata))
 		cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
-		outputShape = typ_3.shape
-		argsDict = OrderedDict()
-		for ii,curDimSize in enumerate(outputShape):
-			argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-		argsDict[expr_1] = "A"
-		argsDict[expr_2] = "B"
-		argsDict[expr_3] = "C"
+
+		node_type = node.type
+		# outArr[s1][s2]
+		out_arr = self.getTempVar()
+		decl_out_arr = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret)
+
+		if Type.isInt(node_type):
+			assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2))
+			out_prog = IR.Prog([assign])
+		else:
+			# Flattening inputs
+			output_shape = node_type.shape
+			inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape
+			inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape
+			out_iters = self.getTempIterators(len(output_shape))
+			expected_output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape)
+			assert(expected_output_shape == output_shape)
+
+			# inp1_arr[i1][i2], inp2_arr[i1][i2], out_arr[i1][i2]
+			inp1_iters = IRUtil.getMaskedIters(broadcast_mask_1, out_iters, inp1_shape)
+			inp2_iters = IRUtil.getMaskedIters(broadcast_mask_2, out_iters, inp2_shape)
+			inp1_arr_expr = IRUtil.addIndex(expr_1, inp1_iters)
+			inp2_arr_expr = IRUtil.addIndex(expr_2, inp2_iters)
+			out_arr_expr = IRUtil.addIndex(out_arr, out_iters)
+
+			flat_size = Util.get_volume(output_shape)
+			inp1_arr_flat = self.getTempVar()
+			inp2_arr_flat = self.getTempVar()
+			out_arr_flat = self.getTempVar()
+			flat_type = Type.Tensor([flat_size], node.expr1.type.bitlen, node.expr1.type.isSecret, node.expr1.type.taint)
+			# inp1_arr_flat[s1*s2]
+			# inp2_arr_flat[s1*s2]
+			# out_arr_flat[s1*s2]
+			decl_inp1_arr_flat = IR.Decl(inp1_arr_flat.idf, flat_type, node.expr1.type.bitlen, node.expr1.type.isSecret)
+			decl_inp2_arr_flat = IR.Decl(inp2_arr_flat.idf, flat_type, node.expr2.type.bitlen, node.expr2.type.isSecret)
+			decl_out_arr_flat = IR.Decl(out_arr_flat.idf, flat_type, node.type.bitlen, node.type.isSecret)
+			# idx
+			flat_idx = self.getTempVar()
+			decl_flat_idx = IR.Decl(flat_idx.idf, Type.Int(bitlen=32), bitlen=32, isSecret=False)
+			# For 4d, generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4);
+			flat_idx_expr = IR.Int(0,32)
+			for i in range(len(out_iters)):
+				vol = Util.get_volume(output_shape[i+1:])
+				flat_idx_expr = IRUtil.add(flat_idx_expr, IRUtil.mul(out_iters[i], IR.Int(vol,32)))
+			# inp1_arr_flat[idx], inp2_arr_flat[idx], out_arr_flat[idx]
+			inp1_arr_flat_expr = IRUtil.addIndex(inp1_arr_flat, [flat_idx])
+			inp2_arr_flat_expr = IRUtil.addIndex(inp2_arr_flat, [flat_idx])
+			out_arr_flat_expr = IRUtil.addIndex(out_arr_flat, [flat_idx])
+			# idx = i1*s2 + i2;
+			# inp1_arr_flat[idx] = inp1_arr[i1][i2]
+			# inp2_arr_flat[idx] = inp2_arr[i1][i2]
+			assign_flat_idx_expr = IR.Assn(flat_idx, flat_idx_expr)
+			assign_inp1_arr_flat = IR.Assn(inp1_arr_flat_expr, inp1_arr_expr)
+			assign_inp2_arr_flat = IR.Assn(inp2_arr_flat_expr, inp2_arr_expr)
+			# Flattening loop
+			# for i1=[0:s1]
+			#   for i2=[0:s2]
+			#     idx = i1*s2 + i2
+			#     inp1_arr_flat[idx] = inp1_arr[i1][i2]
+			#     inp2_arr_flat[idx] = inp2_arr[i1][i2]
+			out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_inp1_arr_flat, assign_inp2_arr_flat])
+			out_prog = IRUtil.Prog(out_loop)
+			decls = [decl_out_arr, decl_inp1_arr_flat, decl_inp2_arr_flat, decl_out_arr_flat, decl_flat_idx]
+			out_prog = IRUtil.prog_merge(IRUtil.Prog(decls), out_prog)
+
+			# Insert call to mul/div functionality
+			argsDict = OrderedDict()
+			argsDict[IR.Int(flat_size, 32)] = "input_shape"
+			if (node.op == AST.Operators.ElemWiseDiv):
+				argsDict[inp1_arr_flat] = "A"
+				argsDict[inp2_arr_flat] = "B"
+				funcName = "ElemwiseSuperDuperSecretDiv"
+				assert False, "Elemwise div not implemented"
+			else:
+				# If either input is a model weight we can use an optimised version of mul.
+				# Otherwise, if both are derived from client input, we use the secret-shared Hadamard version.
+				isMulOptimised = False
+				if not(self.isModel(node.expr1)) and not(self.isModel(node.expr2)):
+					argsDict[inp1_arr_flat] = "A"
+					argsDict[inp2_arr_flat] = "B"
+				else:
+					isMulOptimised = True
+					# Optimised version expects the second parameter to be an input from server
+					if self.isModel(node.expr2):
+						argsDict[inp1_arr_flat] = "A"
+						argsDict[inp2_arr_flat] = "B"
+					else:
+						# Swap the params so the model input is the second argument.
+						argsDict[inp2_arr_flat] = "A"
+						argsDict[inp1_arr_flat] = "B"
+				funcName = "ElemWiseActModelVectorMult" if isMulOptimised else "ElemWiseSecretSharedVectorMult"
+			argsDict[out_arr_flat] = "Output"
+			funcCall = IR.FuncCall(funcName, argsDict)
+			out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+
+			# Clear temp arrays
+			argsDict = OrderedDict()
+			argsDict[IR.Int(flat_size, 32)] = "size"
+			argsDict[inp1_arr_flat] = "A"
+			funcCall = IR.FuncCall("ClearMemSecret1", argsDict)
+			out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+			argsDict = OrderedDict()
+			argsDict[IR.Int(flat_size, 32)] = "size"
+			argsDict[inp2_arr_flat] = "A"
+			funcCall = IR.FuncCall("ClearMemSecret1", argsDict)
+			out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+
+			# Unflatten output
+			assign_out_arr_flat = IR.Assn(out_arr_expr, out_arr_flat_expr)
+			out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_out_arr_flat])
+			out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog(out_loop))
+
+			argsDict = OrderedDict()
+			argsDict[IR.Int(flat_size, 32)] = "size"
+			argsDict[out_arr_flat] = "A"
+			funcCall = IR.FuncCall("ClearMemSecret1", argsDict)
+			out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+
 		progExtraBefore = IR.Prog([])
 		progExtraAfter = IR.Prog([])
 		if (Util.Config.disableTruncOpti):
-			progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF)
+			progExtraAfter = self.addTruncateFunctionCall(node, "ElemWiseMul", out_arr, Util.Config.consSF)
 		else:
 			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
-				progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac)
+				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ElemWiseMul", expr_1, expr1_sf - self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
 			if (not inputs_same) and (expr2_sf > self.scaleFac):
-				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac))
+				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ElemWiseMul", expr_2, expr2_sf - self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
-			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
+			self.scaleFacMapping[out_arr.idf] = 2*self.scaleFac
 
-		funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
-		prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
-		prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
-
-		return (prog_3, expr_3)
+		out_prog = IRUtil.prog_merge(IRUtil.Prog([comment, cmd0]), progExtraBefore, out_prog, progExtraAfter)
+		return (out_prog, out_arr)
 
 	def visitBopMul(self, node:AST.BOp, args=None):
 		typ_1 = node.expr1.type
diff --git a/Athos/SeeDot/IR/IRUtil.py b/Athos/SeeDot/IR/IRUtil.py
index 3963607646bd21ed1e1099f40cc7e9411da11728..1c590a71085638c74e913f2206fcedecf5c31f88 100644
--- a/Athos/SeeDot/IR/IRUtil.py
+++ b/Athos/SeeDot/IR/IRUtil.py
@@ -174,3 +174,58 @@ def print_loop(shape:list, iters:list, cmdl_body:CmdList, factor=0) -> CmdList:
 		cmdl_for = [For(iters[i], 0, lt(iters[i], Int(shape[i])), cmdl_for, factor), Print(Var('""'))]
 	return cmdl_for
 
+# For tensor A of shape = 7 x 1 x 5
+# And out_iters         = [i0, i1, i2, i3]
+# Broadcast mask        = [False, True, False]
+# We generate iters     = A[i1][0][i3]
+# If input is scalar, broadcast_mask=[] and inp_shape=[]
+def getMaskedIters(broadcast_mask: list, out_iters: list, inp_shape : list):
+	base_idx = len(out_iters) - len(inp_shape)
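+	# The input's dims align with the trailing dims of the output,
+	# so indexing into out_iters starts at this offset.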
+	masked_iters = []
+	for i in range(len(broadcast_mask)):
+		if broadcast_mask[i]:
+			masked_iters.append(Int(0,32))
+		else:
+			masked_iters.append(out_iters[base_idx])
+		base_idx += 1
+	return masked_iters
+
+# Given input
+# A      (4d array):  8 x 1 x 6 x 1
+# B      (3d array):      7 x 1 x 5
+# We generate a loop with
+# Result (4d array):  8 x 7 x 6 x 5
+# for i0=[0:8]
+#   for i1=[0:7]
+#     for i2=[0:6]
+#       for i3=[0:5]
+#         Result[i0][i1][i2][i3] = A[i0][0][i2][0] + B[i1][0][i3]
+def generateBroadcastLoopBOp(expr_1, inp1_shape: list, expr_2, inp2_shape : list, expr_out, op: Op.Op):
+	output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape)
+	out_iters = [Var('i' + str(i)) for i in range(len(output_shape))]
+	inp1_iters = getMaskedIters(broadcast_mask_1, out_iters, inp1_shape)
+	inp2_iters = getMaskedIters(broadcast_mask_2, out_iters, inp2_shape)
+
+	inp1_arr_expr = addIndex(expr_1, inp1_iters)
+	inp2_arr_expr = addIndex(expr_2, inp2_iters)
+	out_arr_expr = addIndex(expr_out, out_iters)
+
+	assign_expr = Assn(out_arr_expr, IntBop(inp1_arr_expr, op, inp2_arr_expr))
+	out_loop = loop(output_shape, out_iters, [assign_expr])
+	out_prog = Prog(out_loop)
+	return out_prog
+
+# Generates the index into a flattened tensor.
+# Example:
+# for i1=[0:s1]
+#   for i2=[0:s2]
+#     for i3=[0:s3]
+#       for i4=[0:s4]
+# generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4);
+def getFlatArrIdxExpr(iters:list, shape:list):
+	assert len(iters) == len(shape), "Number of loop index variables must equal the number of dimensions in shape"
+	flat_idx_expr = Int(0,32)
+	for i in range(len(iters)):
+		vol = get_volume(shape[i+1:])
+		flat_idx_expr = add(flat_idx_expr, mul(iters[i], Int(vol,32)))
+	return flat_idx_expr
\ No newline at end of file
diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py
index 0420e7518368d1e83a46851c425bcdf8081b93b8..4b6b5ec61b807a7888dde44d9c354e84591b79e0 100644
--- a/Athos/SeeDot/Type.py
+++ b/Athos/SeeDot/Type.py
@@ -274,11 +274,9 @@ class InferType(ASTVisitor):
 		node.expr2.gamma = dict(node.gamma)
 		fType = self.visit(node.expr2)
 
-		if node.op in [AST.Operators.ADD, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]:
+		if node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]:
 			# Ops supporting broadcasting
 			return self.typeCheckBroadcastOps(node, eType, fType)
-		elif node.op in [AST.Operators.SUB, AST.Operators.Equal]:
-			return self.visitBopAddLike(node, eType, fType)
 		elif node.op == AST.Operators.MUL:
 			return self.visitBopMul(node, eType, fType)
 		elif node.op == AST.Operators.CONV:
@@ -293,35 +291,18 @@ class InferType(ASTVisitor):
 		# If adding a new op here which supports broadcasting, then be careful!
 		# Currently, its assumed the op is commutative. If that is not true, following will be wrong ! 
 
-		assert node.op in [AST.Operators.ADD, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]
-		if (len(eType.shape) < len(fType.shape)):
-			# swap expr1 and expr2 -- this is valid for commutative ops
-			# be careful for ops which are not commutative
-			temp = node.expr1
-			node.expr1 = node.expr2
-			node.expr2 = temp
-
-			temp = eType
-			eType = fType
-			fType = temp
-		
-		# Now true that dim(eType) >= dim(fTYpe)
-		assert len(eType.shape) >= len(fType.shape)
-
+		assert node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]
 		if isInt(eType) and isInt(fType):
 			node.type = Int(eType.bitlen)
 		elif isTensor(eType) and isTensor(fType):
-			revETypeShape = eType.shape[::-1]
-			revFTypeShape = fType.shape[::-1]
-			for i, fTypeCurDim in enumerate(revFTypeShape):
-				eTypeCurDim = revETypeShape[i]
-				if not(eTypeCurDim==1 or fTypeCurDim==1 or eTypeCurDim==fTypeCurDim):
-					# broadcast not possible - raise error
-					print("Broadcast not possible for current node.", eType.shape, fType.shape)
-					assert False
-
-			# Broadcast possible
-			node.type = copy.copy(eType)
+			output_shape, _, _ = Util.getBroadcastShapes(eType.shape, fType.shape)
+			node.type = Tensor(shape=output_shape, bitlen=eType.bitlen)
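+		# A scalar operand broadcasts against every dimension of the tensor operand.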
+		elif isTensor(eType) and isInt(fType):
+			output_shape, _, _ = Util.getBroadcastShapes(eType.shape, [])
+			node.type = Tensor(shape=output_shape, bitlen=eType.bitlen)
+		elif isInt(eType) and isTensor(fType):
+			output_shape, _, _ = Util.getBroadcastShapes([], fType.shape)
+			node.type = Tensor(shape=output_shape, bitlen=eType.bitlen)
 		else:
 			print(eType, fType)
 			assert False
@@ -444,19 +425,6 @@ class InferType(ASTVisitor):
 		node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType))
 		return node.type
 
-	def visitBopAddLike(self, node:AST.BOp, eType: Type, fType: Type, args=None):
-		if isInt(eType) and isInt(fType):
-			pass
-		elif isTensor(eType) and isTensor(fType):
-			assert eType.shape == fType.shape
-		else:
-			assert False
-		
-		node.type = copy.copy(eType)
-		node.type.taint = getTaint_type(eType, fType)
-		node.type.isSecret = eType.isSecret | fType.isSecret
-		return node.type
-
 	def visitFunc(self, node:AST.Func, args=None):
 		node.expr.gamma = dict(node.gamma)
 		eType = self.visit(node.expr)
diff --git a/Athos/SeeDot/Util.py b/Athos/SeeDot/Util.py
index 20ea4c2835ba43a3b0107df69ecef230ac87b3b6..319ee1f14076d2887e0d3330f2c676f8d1eb1f9d 100644
--- a/Athos/SeeDot/Util.py
+++ b/Athos/SeeDot/Util.py
@@ -81,3 +81,64 @@ def write_debug_info(name_mapping):
 	with open('debug/seedot_ezpc_name_map.txt', 'w') as f:
 		for val in name_mapping:
 			f.write(val + '   ' + name_mapping[val] + '\n')		
+
+# Broadcasting Rules:
+# A      (4d array):  8 x 1 x 6 x 1
+# B      (3d array):      7 x 1 x 5
+# Result (4d array):  8 x 7 x 6 x 5
+# Return Values
+# Shape A broadcast mask:  [False, True, False, True]
+# Shape B broadcast mask:  [False, True, False]
+# Result shape:  [8, 7, 6, 5]
+#
+# If input is a scalar, pass shape as []
+def getBroadcastShapes(Shape1 : list, Shape2 : list):
+	# Broadcast rules apply in the reverse direction (innermost dimension first)
+	shape1 = Shape1[::-1]
+	shape2 = Shape2[::-1]
+	len1 = len(shape1)
+	len2 = len(shape2)
+	outputshape = []
+	swapped = False
+	if len1 != len2:
+		if len1 > len2:
+			len1, len2 = len2, len1
+			shape1, shape2 = shape2, shape1
+			swapped = True
+		assert len1 < len2
+
+	broadcastMask1 = [False] * len1
+	broadcastMask2 = [False] * len2
+
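+	# Walk dims from the innermost end; dims missing from the shorter shape are taken
+	# from the longer one, and a size-1 dim is marked for broadcast in that input's mask.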
+	for i in range(len2):
+		if i >= len1:
+			outputshape.append(shape2[i])
+			continue
+		if shape1[i] != shape2[i]:
+			if shape1[i] == 1:
+				outputshape.append(shape2[i])
+				broadcastMask1[i] = True
+			elif shape2[i] == 1:
+				outputshape.append(shape1[i])
+				broadcastMask2[i] = True
+			else:
+				print("Dimension {} has a length mismatch.".format(len2 - i))
+				assert False, "Cannot broadcast. Program is malformed. At least one of the lengths should be 1. i1: {} i2: {}".format(shape1[i], shape2[i])
+		else:
+			outputshape.append(shape1[i])
+
+	if swapped:
+		broadcastMask1, broadcastMask2 = broadcastMask2, broadcastMask1
+
+	outputshape.reverse()
+	broadcastMask1.reverse()
+	broadcastMask2.reverse()
+	return outputshape, broadcastMask1, broadcastMask2
+
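+# Returns the total number of elements in a tensor of the given shape.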
+def get_volume(shape: list):
+	vol = 1
+	for i in shape:
+		vol = vol * i
+	return vol
\ No newline at end of file