From c449271d373b693fded6924eab27453350fd793c Mon Sep 17 00:00:00 2001
From: Bhatu <prbhatu@microsoft.com>
Date: Thu, 26 Nov 2020 18:10:06 +0530
Subject: [PATCH] Add support for broadcasting semantics for binops.

We add broadcasting support for add, sub, mul and equal.
The broadcasting semantics are specified here:
https://numpy.org/doc/stable/user/basics.broadcasting.html

Say we are given input
A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
We generate a loop with
Result (4d array):  8 x 7 x 6 x 5
for i0=[0:8]
  for i1=[0:7]
    for i2=[0:6]
      for i3=[0:5]
        Result[i0][i1][i2][i3] = A[i0][0][i2][0] {+,*,-,==} B[i1][0][i3]
---
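Notes: as a quick sanity check, NumPy reports the same broadcast result shape
for the example above (illustrative only, run in a plain Python interpreter;
not part of the compiler changes):

    >>> import numpy as np
    >>> np.broadcast(np.empty((8, 1, 6, 1)), np.empty((7, 1, 5))).shape
    (8, 7, 6, 5)
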
 Athos/SeeDot/IR/IRBuilderCSF.py | 236 ++++++++++++++++++++++++--------
 Athos/SeeDot/IR/IRUtil.py       |  55 ++++++++
 Athos/SeeDot/Type.py            |  52 ++-----
 Athos/SeeDot/Util.py            |  61 +++++++++
 4 files changed, 303 insertions(+), 101 deletions(-)

diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index ea9b808..6f889ae 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -425,19 +425,16 @@ class IRBuilderCSF(IRBuilderAST):

         op = node.op
         if (op == AST.Operators.ADD):
-            (op_ir, op_fn) = (IR.Op.Op['+'], operator.add)
-            funcName = "MatAdd"
+            op_ir = IR.Op.Op['+']
         elif (op == AST.Operators.SUB):
-            (op_ir, op_fn) = (IR.Op.Op['-'], operator.sub)
-            funcName = "MatSub"
+            op_ir = IR.Op.Op['-']
         elif (op == AST.Operators.Equal):
-            (op_ir, op_fn) = (IR.Op.Op['=='], operator.eq)
-            funcName = "MatEqual"
+            op_ir = IR.Op.Op['==']
         else:
             assert False

-        typ_3 = node.type
-        expr_3 = self.getTempVar()
+        node_type = node.type
+        out_arr = self.getTempVar()

         cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
         comment = IR.Comment(str(node.metadata))
@@ -468,43 +465,54 @@ class IRBuilderCSF(IRBuilderAST):
             argsDict[exprToScale] = "exprToScale, arg#{0}".format(2 if (expr1_sf>expr2_sf) else 1)
             argsDict[IR.Int(scaleUpFactor, 32)] = "ScaleUpFactor"
             funcCall = IR.FuncCall(curFuncName, argsDict)
-            curProg = IR.Prog([comm,funcCall])
+
+            if Type.isInt(typeOfExprToScale) or typeOfExprToScale.shape == []:
+                assn_expr = IR.Assn(exprToScale, funcCall)
+                curProg = IR.Prog([comm,assn_expr])
+            else:
+                curProg = IR.Prog([comm,funcCall])
             prog_1 = IRUtil.prog_merge(curProg, prog_1)
-            self.scaleFacMapping[expr_3.idf] = self.scaleFacMapping[expr_1.idf]
+            self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[expr_1.idf]

-        if Type.isInt(typ_3):
-            decl = IR.Decl(expr_3.idf, typ_3, typ_3.bitlen, typ_3.isSecret)
-            assign = IR.Assn(expr_3, IR.IntBop(expr_1, op_ir, expr_2))
-            prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, decl, assign]))
-        else:
-            ## TODO
-            if (node.type.dim != node.expr1.type.dim):
-                # This needs broadcast of expr1
-                assert False # For now this shouldn't occur
-            if (node.type.dim != node.expr2.type.dim):
-                # This needs broadcast of expr2
-                funcName += 'BroadCast'
-
-            outputShape = typ_3.shape
-            argsDict = OrderedDict()
-            inp1_shape = node.expr1.type.shape
-            inp2_shape = node.expr2.type.shape
-            for ii,curDimSize in enumerate(inp1_shape):
-                argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-            for ii,curDimSize in enumerate(inp2_shape):
-                argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-            for ii,curDimSize in enumerate(outputShape):
-                argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-            argsDict[expr_1] = "A"
-            argsDict[expr_2] = "B"
-            argsDict[expr_3] = "C"
-            funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
-            prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, funcCall]))
-        prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)
-
-        return (prog_3, expr_3)
+        decl = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret)
+        if Type.isInt(node_type):
+            assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2))
+            out_prog = IR.Prog([assign])
+        else:
+            outputShape = node_type.shape
+            inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape
+            inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape
+
+            expected_output_shape, _, _ = Util.getBroadcastShapes(inp1_shape, inp2_shape)
+            assert(outputShape == expected_output_shape)
+            out_prog = IRUtil.generateBroadcastLoopBOp(expr_1, inp1_shape, expr_2, inp2_shape, out_arr, op_ir)
+
+        out_prog = IRUtil.prog_merge(IR.Prog([comment, cmd0, decl]), out_prog)
+        out_prog = IRUtil.prog_merge(prog_1, prog_2, out_prog)
+        return (out_prog, out_arr)
+
+
+    # We first reshape both inputs and flatten them into 1-d arrays.
+    # For simplicity consider a non-broadcast example:
+    #   inputs           : inp1_arr[s1][s2], inp2_arr[s1][s2]
+    #   after flattening : inp1_arr_flat[s1*s2], inp2_arr_flat[s1*s2]
+    #     for i1=[0:s1]
+    #       for i2=[0:s2]
+    #         idx = i1*s2 + i2
+    #         inp1_arr_flat[idx] = inp1_arr[i1][i2]
+    #         inp2_arr_flat[idx] = inp2_arr[i1][i2]
+    # If one input is a model weight (from the server) and the other is an
+    # activation, we can call an optimized version of mul
+    #   ElemWiseActModelVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat)  <- optimized
+    # OR
+    #   ElemWiseSecretSharedVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat)
+    # Finally we reshape the flattened output
+    #   for i1=[0:s1]
+    #     for i2=[0:s2]
+    #       idx = i1*s2 + i2
+    #       out_arr[i1][i2] = out_arr_flat[idx]
+    # Standard broadcast rules apply to generate these flattened tensors.
     def visitBopElemWiseOp(self, node:AST.BOp, args=None):
         (prog_1, expr_1) = self.visit(node.expr1)
         (prog_2, expr_2) = self.visit(node.expr2)
@@ -515,38 +523,148 @@ class IRBuilderCSF(IRBuilderAST):
         elif (node.op == AST.Operators.ElemWiseDiv):
             op_ir = IR.Op.Op['./']
             funcName = "ElemWiseDiv"
+            assert False, "Did not implement div yet"
+        else:
+            assert False, "Non mul/div elemwise op"

-        typ_3 = node.type
-        expr_3 = self.getTempVar()
+        comment = IR.Comment(str(node.metadata))
         cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
-        outputShape = typ_3.shape
-        argsDict = OrderedDict()
-        for ii,curDimSize in enumerate(outputShape):
-            argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
-        argsDict[expr_1] = "A"
-        argsDict[expr_2] = "B"
-        argsDict[expr_3] = "C"
+
+        node_type = node.type
+        # outArr[s1][s2]
+        out_arr = self.getTempVar()
+        decl_out_arr = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret)
+
+        if Type.isInt(node_type):
+            assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2))
+            out_prog = IR.Prog([assign])
+        else:
+            # Flattening inputs
+            output_shape = node_type.shape
+            inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape
+            inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape
+            out_iters = self.getTempIterators(len(output_shape))
+            expected_output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape)
+            assert(expected_output_shape == output_shape)
+
+            # inp1_arr[i1][i2], inp2_arr[i1][i2], out_arr[i1][i2]
+            inp1_iters = IRUtil.getMaskedIters(broadcast_mask_1, out_iters, inp1_shape)
+            inp2_iters = IRUtil.getMaskedIters(broadcast_mask_2, out_iters, inp2_shape)
+            inp1_arr_expr = IRUtil.addIndex(expr_1, inp1_iters)
+            inp2_arr_expr = IRUtil.addIndex(expr_2, inp2_iters)
+            out_arr_expr = IRUtil.addIndex(out_arr, out_iters)
+
+            flat_size = Util.get_volume(output_shape)
+            inp1_arr_flat = self.getTempVar()
+            inp2_arr_flat = self.getTempVar()
+            out_arr_flat = self.getTempVar()
+            flat_type = Type.Tensor([flat_size], node.expr1.type.bitlen, node.expr1.type.isSecret, node.expr1.type.taint)
+            # inp1_arr_flat[s1*s2]
+            # inp2_arr_flat[s1*s2]
+            # out_arr_flat[s1*s2]
+            decl_inp1_arr_flat = IR.Decl(inp1_arr_flat.idf, flat_type, node.expr1.type.bitlen, node.expr1.type.isSecret)
+            decl_inp2_arr_flat = IR.Decl(inp2_arr_flat.idf, flat_type, node.expr2.type.bitlen, node.expr2.type.isSecret)
+            decl_out_arr_flat = IR.Decl(out_arr_flat.idf, flat_type, node.type.bitlen, node.type.isSecret)
+            # idx
+            flat_idx = self.getTempVar()
+            decl_flat_idx = IR.Decl(flat_idx.idf, Type.Int(bitlen=32), bitlen=32, isSecret=False)
+            # For 4d, generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4);
+            flat_idx_expr = IR.Int(0,32)
+            for i in range(len(out_iters)):
+                vol = Util.get_volume(output_shape[i+1:])
+                flat_idx_expr = IRUtil.add(flat_idx_expr, IRUtil.mul(out_iters[i], IR.Int(vol,32)))
+            # inp1_arr_flat[idx], inp2_arr_flat[idx], out_arr_flat[idx]
+            inp1_arr_flat_expr = IRUtil.addIndex(inp1_arr_flat, [flat_idx])
+            inp2_arr_flat_expr = IRUtil.addIndex(inp2_arr_flat, [flat_idx])
+            out_arr_flat_expr = IRUtil.addIndex(out_arr_flat, [flat_idx])
+            # idx = i1*s2 + i2;
+            # inp1_arr_flat[idx] = inp1_arr[i1][i2]
+            # inp2_arr_flat[idx] = inp2_arr[i1][i2]
+            assign_flat_idx_expr = IR.Assn(flat_idx, flat_idx_expr)
+            assign_inp1_arr_flat = IR.Assn(inp1_arr_flat_expr, inp1_arr_expr)
+            assign_inp2_arr_flat = IR.Assn(inp2_arr_flat_expr, inp2_arr_expr)
+            # Flattening loop
+            #   for i1=[0:s1]
+            #     for i2=[0:s2]
+            #       idx = i1*s2 + i2
+            #       inp1_arr_flat[idx] = inp1_arr[i1][i2]
+            #       inp2_arr_flat[idx] = inp2_arr[i1][i2]
+            out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_inp1_arr_flat, assign_inp2_arr_flat])
+            out_prog = IRUtil.Prog(out_loop)
+            decls = [decl_out_arr, decl_inp1_arr_flat, decl_inp2_arr_flat, decl_out_arr_flat, decl_flat_idx]
+            out_prog = IRUtil.prog_merge(IRUtil.Prog(decls), out_prog)
+
+            # Insert call to mul/div functionality
+            argsDict = OrderedDict()
+            argsDict[IR.Int(flat_size, 32)] = "input_shape"
+            if (node.op == AST.Operators.ElemWiseDiv):
+                argsDict[inp1_arr_flat] = "A"
+                argsDict[inp2_arr_flat] = "B"
+                funcName = "ElemwiseSuperDuperSecretDiv"
+                assert False, "Elemwise div not implemented"
+            else:
+                # If either input is a model weight we can use an optimised version for mul.
+                # Otherwise, if both are derived from client input, we use the Hadamard version.
+                isMulOptimised = False
+                if not(self.isModel(node.expr1)) and not(self.isModel(node.expr2)):
+                    argsDict[inp1_arr_flat] = "A"
+                    argsDict[inp2_arr_flat] = "B"
+                else:
+                    isMulOptimised = True
+                    # Optimised version expects the second parameter to be an input from server
+                    if self.isModel(node.expr2):
+                        argsDict[inp1_arr_flat] = "A"
+                        argsDict[inp2_arr_flat] = "B"
+                    else:
+                        # Shuffle the params.
+                        argsDict[inp2_arr_flat] = "A"
+                        argsDict[inp1_arr_flat] = "B"
+                funcName = "ElemWiseActModelVectorMult" if isMulOptimised else "ElemWiseSecretSharedVectorMult"
+            argsDict[out_arr_flat] = "Output"
+            funcCall = IR.FuncCall(funcName, argsDict)
+            out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+
+            # Clear temp arrays
+            argsDict = OrderedDict()
+            argsDict[IR.Int(flat_size, 32)] = "size"
+            argsDict[inp1_arr_flat] = "A"
+            funcCall = IR.FuncCall("ClearMemSecret1", argsDict)
+            out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+            argsDict = OrderedDict()
+            argsDict[IR.Int(flat_size, 32)] = "size"
+            argsDict[inp2_arr_flat] = "A"
+            funcCall = IR.FuncCall("ClearMemSecret1", argsDict)
+            out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+
+            # Unflatten output
+            assign_out_arr_flat = IR.Assn(out_arr_expr, out_arr_flat_expr)
+            out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_out_arr_flat])
+            out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog(out_loop))
+
+            argsDict = OrderedDict()
+            argsDict[IR.Int(flat_size, 32)] = "size"
+            argsDict[out_arr_flat] = "A"
+            funcCall = IR.FuncCall("ClearMemSecret1", argsDict)
+            out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall]))
+
         progExtraBefore = IR.Prog([])
         progExtraAfter = IR.Prog([])
         if (Util.Config.disableTruncOpti):
-            progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF)
+            progExtraAfter = self.addTruncateFunctionCall(node, "ElemWiseMul", out_arr, Util.Config.consSF)
         else:
             inputs_same = (expr_1.idf == expr_2.idf)
             expr1_sf = self.scaleFacMapping[expr_1.idf]
             expr2_sf = self.scaleFacMapping[expr_2.idf]
             if (expr1_sf > self.scaleFac):
-                progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac)
+                progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ElemWiseMul", expr_1, expr1_sf - self.scaleFac)
                 self.scaleFacMapping[expr_1.idf] = self.scaleFac
             if (not inputs_same) and (expr2_sf > self.scaleFac):
-                progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac))
+                progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ElemWiseMul", expr_2, expr2_sf - self.scaleFac))
                 self.scaleFacMapping[expr_2.idf] = self.scaleFac
-            self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
+            self.scaleFacMapping[out_arr.idf] = 2*self.scaleFac

-        funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
-        prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
-        prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
-
-        return (prog_3, expr_3)
+        out_prog = IRUtil.prog_merge(IRUtil.Prog([comment, cmd0]), progExtraBefore, out_prog, progExtraAfter)
+        return (out_prog, out_arr)

     def visitBopMul(self, node:AST.BOp, args=None):
         typ_1 = node.expr1.type
diff --git a/Athos/SeeDot/IR/IRUtil.py b/Athos/SeeDot/IR/IRUtil.py
index 3963607..1c590a7 100644
--- a/Athos/SeeDot/IR/IRUtil.py
+++ b/Athos/SeeDot/IR/IRUtil.py
@@ -174,3 +174,58 @@ def print_loop(shape:list, iters:list, cmdl_body:CmdList, factor=0) -> CmdList:
         cmdl_for = [For(iters[i], 0, lt(iters[i], Int(shape[i])), cmdl_for, factor), Print(Var('""'))]
     return cmdl_for

+# For tensor A of shape = 7 x 1 x 5
+# And out_iters = [i0, i1, i2, i3]
+# Broadcast mask = [False, True, False]
+# We generate iters = A[i1][0][i3]
+# If input is scalar, broadcast_mask=[] and inp_shape=[]
+def getMaskedIters(broadcast_mask: list, out_iters: list, inp_shape : list):
+    base_idx = len(out_iters) - len(inp_shape)
+    masked_iters = []
+    for i in range(len(broadcast_mask)):
+        if broadcast_mask[i]:
+            masked_iters.append(Int(0,32))
+        else:
+            masked_iters.append(out_iters[base_idx])
+        base_idx += 1
+    return masked_iters
+
+# Given input
+# A (4d array): 8 x 1 x 6 x 1
+# B (3d array):     7 x 1 x 5
+# We generate a loop with
+# Result (4d array): 8 x 7 x 6 x 5
+#   for i0=[0:8]
+#     for i1=[0:7]
+#       for i2=[0:6]
+#         for i3=[0:5]
+#           Result[i0][i1][i2][i3] = A[i0][0][i2][0] + B[i1][0][i3]
+def generateBroadcastLoopBOp(expr_1, inp1_shape: list, expr_2, inp2_shape : list, expr_out, op: Op.Op):
+    output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape)
+    out_iters = [Var('i' + str(i)) for i in range(len(output_shape))]
+    inp1_iters = getMaskedIters(broadcast_mask_1, out_iters, inp1_shape)
+    inp2_iters = getMaskedIters(broadcast_mask_2, out_iters, inp2_shape)
+
+    inp1_arr_expr = addIndex(expr_1, inp1_iters)
+    inp2_arr_expr = addIndex(expr_2, inp2_iters)
+    out_arr_expr = addIndex(expr_out, out_iters)
+
+    assign_expr = Assn(out_arr_expr, IntBop(inp1_arr_expr, op, inp2_arr_expr))
+    out_loop = loop(output_shape, out_iters, [assign_expr])
+    out_prog = Prog(out_loop)
+    return out_prog
+
+# Generates the index into a flattened tensor.
+# Example:
+#   for i1=[0:s1]
+#     for i2=[0:s2]
+#       for i3=[0:s3]
+#         for i4=[0:s4]
+# generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4);
+def getFlatArrIdxExpr(iters:list, shape:list):
+    assert len(iters) == len(shape), "No. of loop idx vars should be equal to loop shapes"
+    flat_idx_expr = Int(0,32)
+    for i in range(len(iters)):
+        vol = get_volume(shape[i+1:])
+        flat_idx_expr = add(flat_idx_expr, mul(iters[i], Int(vol,32)))
+    return flat_idx_expr
\ No newline at end of file
diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py
index 0420e75..4b6b5ec 100644
--- a/Athos/SeeDot/Type.py
+++ b/Athos/SeeDot/Type.py
@@ -274,11 +274,9 @@ class InferType(ASTVisitor):
         node.expr2.gamma = dict(node.gamma)
         fType = self.visit(node.expr2)

-        if node.op in [AST.Operators.ADD, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]:
+        if node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]:  # Ops supporting broadcasting
             return self.typeCheckBroadcastOps(node, eType, fType)
-        elif node.op in [AST.Operators.SUB, AST.Operators.Equal]:
-            return self.visitBopAddLike(node, eType, fType)
         elif node.op == AST.Operators.MUL:
             return self.visitBopMul(node, eType, fType)
         elif node.op == AST.Operators.CONV:
@@ -293,35 +291,18 @@ class InferType(ASTVisitor):
         # If adding a new op here which supports broadcasting, then be careful!
         # Currently, its assumed the op is commutative. If that is not true, following will be wrong !
-        assert node.op in [AST.Operators.ADD, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]
-        if (len(eType.shape) < len(fType.shape)):
-            # swap expr1 and expr2 -- this is valid for commutative ops
-            # be careful for ops which are not commutative
-            temp = node.expr1
-            node.expr1 = node.expr2
-            node.expr2 = temp
-
-            temp = eType
-            eType = fType
-            fType = temp
-
-        # Now true that dim(eType) >= dim(fTYpe)
-        assert len(eType.shape) >= len(fType.shape)
-
+        assert node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]
         if isInt(eType) and isInt(fType):
             node.type = Int(eType.bitlen)
         elif isTensor(eType) and isTensor(fType):
-            revETypeShape = eType.shape[::-1]
-            revFTypeShape = fType.shape[::-1]
-            for i, fTypeCurDim in enumerate(revFTypeShape):
-                eTypeCurDim = revETypeShape[i]
-                if not(eTypeCurDim==1 or fTypeCurDim==1 or eTypeCurDim==fTypeCurDim):
-                    # broadcast not possible - raise error
-                    print("Broadcast not possible for current node.", eType.shape, fType.shape)
-                    assert False
-
-            # Broadcast possible
-            node.type = copy.copy(eType)
+            output_shape, _, _ = Util.getBroadcastShapes(eType.shape, fType.shape)
+            node.type = Tensor(shape=output_shape, bitlen=eType.bitlen)
+        elif isTensor(eType) and isInt(fType):
+            output_shape, _, _ = Util.getBroadcastShapes(eType.shape, [])
+            node.type = Tensor(shape=output_shape, bitlen=eType.bitlen)
+        elif isInt(eType) and isTensor(fType):
+            output_shape, _, _ = Util.getBroadcastShapes([], fType.shape)
+            node.type = Tensor(shape=output_shape, bitlen=eType.bitlen)
         else:
             print(eType, fType)
             assert False
@@ -444,19 +425,6 @@ class InferType(ASTVisitor):
         node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType))
         return node.type

-    def visitBopAddLike(self, node:AST.BOp, eType: Type, fType: Type, args=None):
-        if isInt(eType) and isInt(fType):
-            pass
-        elif isTensor(eType) and isTensor(fType):
-            assert eType.shape == fType.shape
-        else:
-            assert False
-
-        node.type = copy.copy(eType)
-        node.type.taint = getTaint_type(eType, fType)
-        node.type.isSecret = eType.isSecret | fType.isSecret
-        return node.type
-
     def visitFunc(self, node:AST.Func, args=None):
         node.expr.gamma = dict(node.gamma)
         eType = self.visit(node.expr)
diff --git a/Athos/SeeDot/Util.py b/Athos/SeeDot/Util.py
index 20ea4c2..319ee1f 100644
--- a/Athos/SeeDot/Util.py
+++ b/Athos/SeeDot/Util.py
@@ -81,3 +81,64 @@ def write_debug_info(name_mapping):
     with open('debug/seedot_ezpc_name_map.txt', 'w') as f:
         for val in name_mapping:
             f.write(val + ' ' + name_mapping[val] + '\n')
+
+# Broadcasting Rules:
+# A      (4d array): 8 x 1 x 6 x 1
+# B      (3d array):     7 x 1 x 5
+# Result (4d array): 8 x 7 x 6 x 5
+# Return Values
+# Shape A broadcast mask: [False, True, False, True]
+# Shape B broadcast mask: [False, True, False]
+# Result shape: [8, 7, 6, 5]
+#
+# If input is a scalar, pass shape as []
+def getBroadcastShapes(Shape1 : list, Shape2 : list):
+    # Broadcast rules apply in reverse direction
+    shape1 = Shape1[::-1]
+    shape2 = Shape2[::-1]
+    len1 = len(shape1)
+    len2 = len(shape2)
+    outputshape = []
+    swapped = False
+    if len1 != len2:
+        if len1 > len2:
+            len1, len2 = len2, len1
+            shape1, shape2 = shape2, shape1
+            swapped = True
+        assert len1 < len2
+
+    broadcastMask1 = [False] * len1
+    broadcastMask2 = [False] * len2
+
+    for i in range(len2):
+        length = 0
+        if i >= len1:
+            #broadcastMask1[i] = True
+            outputshape.append(shape2[i])
+            continue
+        if shape1[i] != shape2[i]:
+            if shape1[i] == 1:
+                outputshape.append(shape2[i])
+                broadcastMask1[i] = True
+            elif shape2[i] == 1:
+                outputshape.append(shape1[i])
+                broadcastMask2[i] = True
+            else:
+                print("Dimension no. {} has a length mismatch.".format(len2 - i))
+                assert False, "Cannot broadcast. Program is malformed. At least one length should have been 1. i1: {} i2: {}".format(shape1[i], shape2[i])
+        else:
+            outputshape.append(shape1[i])
+
+    if swapped:
+        broadcastMask1, broadcastMask2 = broadcastMask2, broadcastMask1
+
+    outputshape.reverse()
+    broadcastMask1.reverse()
+    broadcastMask2.reverse()
+    return outputshape, broadcastMask1, broadcastMask2
+
+def get_volume(shape: list):
+    vol = 1
+    for i in shape:
+        vol = vol * i
+    return vol
\ No newline at end of file
--
GitLab
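
A minimal pure-Python model of the mask convention used by Util.getBroadcastShapes,
assuming only the behaviour described in the comments above (the helper and variable
names below are local to this sketch, not part of the patch); it reproduces the
documented example:

    def broadcast_masks(shape1, shape2):
        # Right-align both shapes, padding the shorter one with None on the left.
        ndim = max(len(shape1), len(shape2))
        s1 = [None] * (ndim - len(shape1)) + list(shape1)
        s2 = [None] * (ndim - len(shape2)) + list(shape2)
        out = []
        for d1, d2 in zip(s1, s2):
            a, b = d1 or 1, d2 or 1
            assert a == b or a == 1 or b == 1, "cannot broadcast"
            out.append(max(a, b))
        # A dimension is "broadcast" when the input has length 1 but the output does not.
        mask1 = [d == 1 and o > 1 for d, o in zip(s1[ndim - len(shape1):], out[ndim - len(shape1):])]
        mask2 = [d == 1 and o > 1 for d, o in zip(s2[ndim - len(shape2):], out[ndim - len(shape2):])]
        return out, mask1, mask2

    # Example from the docstring above:
    #   broadcast_masks([8, 1, 6, 1], [7, 1, 5])
    #   -> ([8, 7, 6, 5], [False, True, False, True], [False, True, False])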
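
Likewise, a short sketch of the row-major flattening index that visitBopElemWiseOp
and getFlatArrIdxExpr build (idx = i1*s2*...*sn + ... + in); again, the helper name
is only for illustration:

    def flat_index(idx, shape):
        # Row-major offset: idx[0]*shape[1]*...*shape[n-1] + ... + idx[n-1]
        flat, stride = 0, 1
        for i, s in reversed(list(zip(idx, shape))):
            flat += i * stride
            stride *= s
        return flat

    # Matches the "idx = i1*s2 + i2" comment in the flattening loop, e.g.
    # flat_index([2, 3], [4, 5]) == 2*5 + 3 == 13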