-
Deevashwer authored
* CrypTFlow2 code * Some small corrections. * Updated License to 2020 * Updated repo README to include SCI (CrypTFlow2) and moved CrypTFlow2/ to SCI/ Co-authored-by:
Nishant Kumar <nishant.kr10@gmail.com>
Deevashwer authored* CrypTFlow2 code * Some small corrections. * Updated License to 2020 * Updated repo README to include SCI (CrypTFlow2) and moved CrypTFlow2/ to SCI/ Co-authored-by:
Nishant Kumar <nishant.kr10@gmail.com>
IRBuilderCSF.py 43.56 KiB
'''
Authors: Nishant Kumar.
Copyright:
Copyright (c) 2020 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import os, math, sys
import operator
import numpy as np
from collections import OrderedDict
import Util
import IR.IR as IR
import Type as Type
import AST.AST as AST
import IR.IRUtil as IRUtil
from AST.ASTVisitor import ASTVisitor
class IRBuilderCSF(ASTVisitor):
varNameDelim = ''
def __init__(self, intPartBitwidth=-1):
# For tracking temp variables
self._var_cnt = 0
self._iter_cnt = 0
# Global variables
# Used to note declarations which will go before any statements
# But since this affects memory consumption, use carefully
self.globalDecls = {} #Mapping of (identifier name (string) -> list of [type, secret/public variable, bitlen of decl])
# The 2nd arg can be either 'secret' or 'public'.
# If public/secret unspecified, default to 'secret'.
# The 3rd arg is used to specify the bitlen of the decl.
# Name mapping from SeeDot names to new names is useful for debugging
self.name_mapping = {}
#This is for optimizing the #truncation calls
self.scaleFac = Util.Config.consSF
self.bitwidth = Util.Config.wordLength
self.intPartBitwidth = intPartBitwidth
if (self.intPartBitwidth==-1):
self.intPartBitwidth = self.bitwidth - 2*self.scaleFac
self.scaleFacMapping = {}
self.inputByPartyMapping = {}
def getConsSF(self):
return Util.Config.consSF
# Variable and iterators creation
def getTempVars(self, n:int):
return [self.getTempVar() for i in range(n)]
def getTempVar(self):
var = IR.Var('tmp' + str(self._var_cnt))
self._var_cnt += 1
return var
def getTempIterators(self, n:int):
return [self.getTempIterator() for i in range(n)]
def getTempIterator(self):
var = IR.Var('i' + str(self._iter_cnt))
self._iter_cnt += 1
return var
# Computing exponent and intervals
def get_expnt(self, maxabs:float): # -> int
return self.getConsSF()
def addTruncateFunctionCallHelper(self, exprNumToScaleDown:int, expr1:IR.Var, expr2:IR.Var, node:AST.BOp):
assert(isinstance(node, AST.BOp))
if (exprNumToScaleDown==1):
exprToScaleDown = expr1
nodeToScaleDown = node.expr1
else:
assert(exprNumToScaleDown==2)
exprToScaleDown = expr2
nodeToScaleDown = node.expr2
return (exprToScaleDown, nodeToScaleDown)
def addTruncateFunctionCall(self, node:AST.ASTNode, nodeTypeStr: str, expr:IR.Var, consSF:int):
comment = IR.Comment("Truncation before {0} node.".format(nodeTypeStr))
argsDict = OrderedDict()
funcName = "ScaleDown"
if not(Type.isInt(node.type)):
outputShape = node.type.shape
for ii,curDimSize in enumerate(outputShape):
argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
funcName = funcName + str(len(outputShape))
argsDict[expr] = "expr"
argsDict[IR.Int(consSF,32)] = "consSF"
funcCall = IR.FuncCall(funcName, argsDict)
prog = IR.Prog([comment, funcCall])
return prog
#=================
# Visit Functions
#=================
def visitInt(self, node:AST.Int, args=None):
n = node.value
prog = IR.Prog([IR.Comment('Int node, isSecret = {0}.'.format(node.isSecret))])
expr = self.getTempVar()
bitlen = -1
if node.bitLen:
bitlen = node.bitLen
prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type, bitlen, node.isSecret, [n])]), prog)
if (not(Util.Config.disableTruncOpti)):
self.scaleFacMapping[expr.idf] = self.scaleFac if node.isScaled else 0
return (prog, expr)
def visitFloat(self, node:AST.Float, args=None):
r = node.value
p = self.get_expnt(abs(r))
k = IR.DataType.getInt(np.ldexp(r, p))
expr = self.getTempVar()
prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(r), str(k), node.isSecret))])
prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type, -1, node.isSecret, [k])]), prog)
if (not(Util.Config.disableTruncOpti)):
self.scaleFacMapping[expr.idf] = self.scaleFac
return (prog, expr)
def visitId(self, node:AST.ID, args=None):
idf = node.name
prog = IR.Prog([])
expr = IR.Var(idf)
if not(Util.Config.disableTruncOpti):
assert(expr.idf in self.scaleFacMapping)
return (prog, expr)
def visitDecl(self, node:AST.Decl, args=None):
def helperAssignGen(l1, l2, allComb):
if l2 == []:
allComb.append(l1)
else:
for cur in range(l2[0]):
helperAssignGen(l1 + [cur], l2[1:], allComb)
prog = IR.Prog([])
expr = self.getTempVar()
expr.inputVar = True
# If there is a valueList, then add assignment commands
specialBitLen = -1
if node.valueList:
# Add assignment statements for each element of the tensor in a different array
comment = IR.Comment(str(node.metadata))
prog = IRUtil.prog_merge(prog, IR.Prog([comment, IR.Comment('Element assignments for ' + expr.idf)]))
allComb = []
helperAssignGen([], node.shape, allComb)
for i,curComb in enumerate(allComb):
curVal = node.valueList[i]
finalVal = None
if isinstance(curVal, AST.Int):
finalVal = IR.Int(curVal.value, curVal.bitLen)
if (specialBitLen == -1 and curVal.bitLen != Util.Config.wordLength):
specialBitLen = curVal.bitLen
elif isinstance(curVal, AST.Float):
finalVal = IR.DataType.getInt(np.ldexp(curVal.value, Util.Config.consSF))
else:
# Assuming the elements can only be either int or floats
assert False
prog = IRUtil.prog_merge(prog, IR.Prog([IR.Assn(IRUtil.addIndex(expr, list(map(lambda x: IR.Int(x), curComb))),
finalVal)]))
prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf,
node.type,
Util.Config.wordLength if specialBitLen == -1 else specialBitLen,
node.isSecret
)]), prog)
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[expr.idf] = self.scaleFac if node.isScaled else 0
return (prog, expr)
def visitTranspose(self, node:AST.Transpose, args=None):
(inp_prog, inp_arr) = self.visit(node.expr)
inp_type = node.expr.type
out_type = node.type
inp_iters = self.getTempIterators(inp_type.dim)
out_iters = []
perm = node.perm
if (perm is None):
perm = [i for i in reversed(range(len(inp_type.shape)))]
for i in perm:
out_iters.append(inp_iters[i])
out_arr = self.getTempVar()
out_arr_expr = IRUtil.addIndex(out_arr, out_iters)
inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters)
assign_expr = IR.Assn(out_arr_expr, inp_arr_expr)
loop = IRUtil.loop(inp_type.shape, inp_iters, [assign_expr])
# Finalize
comment1 = IR.Comment(str(node.metadata))
comment2 = IR.Comment("transpose(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])")
transpose_prog = IR.Prog([comment1, comment2] + loop)
final_prog = IRUtil.prog_merge(inp_prog, transpose_prog)
for var in inp_iters:
final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog)
final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog)
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf]
return (final_prog, out_arr)
def visitReshape(self, node:AST.Reshape, args=None):
(prog_1, expr_1) = self.visit(node.expr)
'''
reshape(A, n, h, w)
cmd1: t1 = t2 = t3 = 0;
loop2: for n in 0:N:
for h in 0:H:
for w in 0:W:
cmd3: B[n][h][w] = A[t1][t2][t3]
cmd4: t3++;
cmd5: if (t3 == WW)
t3 = 0;
t2++;
if (t2 == HH)
t2 = 0;
t1++;
'''
typ_1 = node.expr.type
typ_2 = node.type
# Declare variables
expr_2 = self.getTempVar()
iters_1 = self.getTempIterators(typ_1.dim)
iters_2 = self.getTempIterators(typ_2.dim)
# Initialize to 0
cmd1 = [IR.Assn(var, IRUtil.zero) for var in iters_1]
# Incrementing the first index
first_iter = iters_1[0]
cmd4 = IRUtil.incCmd(first_iter)
# Incrementing other indices using a loop
cmd5 = [cmd4]
for i in range(1, typ_1.dim):
curr_iter = iters_1[i]
curr_size = IR.Int(typ_1.shape[i])
cmd5 = [IRUtil.incCmd(curr_iter), IR.If(IRUtil.eq(curr_iter, curr_size), [IRUtil.initVarToZero(curr_iter)] + cmd5)]
# Outer loop
# The iterators are selected based on the selection order specified by the user
loopShape = []
loopIters = []
if(node.order):
for order in node.order:
order = order - 1
loopShape.append(typ_2.shape[order])
loopIters.append(iters_2[order])
else:
loopShape = typ_2.shape
loopIters = iters_2
loop2 = IRUtil.loop(loopShape, loopIters, [IR.Assn(IRUtil.addIndex(expr_2, iters_2), IRUtil.addIndex(expr_1, iters_1))] + cmd5)
# Finalize
comment1 = IR.Comment(str(node.metadata))
comment2 = IR.Comment("reshape(" + expr_1.idf + ", " + ', '.join(str(e) for e in typ_2.shape) + ")")
reshape_prog = IR.Prog([comment1, comment2] + cmd1 + loop2)
prog_2 = IRUtil.prog_merge(prog_1, reshape_prog)
for var in iters_1:
prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2)
for var in iters_2:
prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2)
prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2)
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]
return (prog_2, expr_2)
def visitPool(self, node:AST.Pool, args=None):
(prog_1, expr_1) = self.visit(node.expr)
[N, H, W, CI] = node.expr.type.shape
[N1, outH, outW, CI1] = node.type.shape
assert(N==N1 and CI==CI1)
[expr_2] = self.getTempVars(1)
comment = IR.Comment(str(node.metadata))
funcCallArgsDict = OrderedDict()
funcCallArgsDict[IR.Int(N1, 32)] = "N1"
funcCallArgsDict[IR.Int(outH, 32)] = "outH"
funcCallArgsDict[IR.Int(outW, 32)] = "outW"
funcCallArgsDict[IR.Int(CI1, 32)] = "CI1"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.FH], 32)] = "FH"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.FW], 32)] = "FW"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32)] = "zPadHLeft"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32)] = "zPadHRight"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32)] = "zPadWLeft"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32)] = "zPadWRight"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideH], 32)] = "strideH"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideW], 32)] = "strideW"
funcCallArgsDict[IR.Int(N, 32)] = "N"
funcCallArgsDict[IR.Int(H, 32)] = "H"
funcCallArgsDict[IR.Int(W, 32)] = "W"
funcCallArgsDict[IR.Int(CI, 32)] = "CI"
funcCallArgsDict[expr_1] = "input"
funcCallArgsDict[expr_2] = "output"
funcCall = IR.FuncCall(node.poolType, funcCallArgsDict)
prog_pool = IR.Prog([comment, funcCall])
prog_2 = IRUtil.prog_merge(prog_1, prog_pool)
prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2)
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]
return (prog_2, expr_2)
def visitUOp(self, node:AST.UOp, args=None):
(prog_1, expr_1) = self.visit(node.expr)
op = node.op
if op == AST.Operators.ADD:
return (prog_1, expr_1)
assert op == AST.Operators.SUB
typ_2 = node.type
expr_2 = self.getTempVar()
if Type.isInt(typ_2):
comment = IR.Comment(str(node.metadata))
bitlen = node.expr.bitlen
decl = IR.Decl(expr_2.idf, node.type, typ_2.bitlen, typ_2.isSecret)
assign = IR.Assn(expr_2, IRUtil.negate(expr_1))
prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment, decl, assign]))
else:
# decl fresh vars
iters = self.getTempIterators(typ_2.dim)
# cmdl_assn
expr_1_elt = IRUtil.addIndex(expr_1, iters)
expr_2_elt = IRUtil.addIndex(expr_2, iters)
cmdl_assn = IRUtil.loop(typ_2.shape, iters, [IR.Assn(expr_2_elt, IRUtil.negate(expr_1_elt))])
comment = IR.Comment(str(node.metadata))
prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn))
prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2)
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]
return (prog_2, expr_2)
def visitBOp(self, node:AST.BOp, args=None):
op = node.op
if (op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal]): return self.visitBopAddOrSubLike(node)
elif (op in [AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]): return self.visitBopElemWiseOp(node)
elif op == AST.Operators.MUL: return self.visitBopMul(node)
elif op == AST.Operators.CONV: return self.visitBopConv(node)
elif op == AST.Operators.CONVTRANSPOSE: return self.visitBopConvTranspose(node)
else: assert False
def visitBopAddOrSubLike(self, node:AST.BOp, args=None):
(prog_1, expr_1) = self.visit(node.expr1)
(prog_2, expr_2) = self.visit(node.expr2)
op = node.op
if (op == AST.Operators.ADD):
(op_ir, op_fn) = (IR.Op.Op['+'], operator.add)
funcName = "MatAdd"
elif (op == AST.Operators.SUB):
(op_ir, op_fn) = (IR.Op.Op['-'], operator.sub)
funcName = "MatSub"
elif (op == AST.Operators.Equal):
(op_ir, op_fn) = (IR.Op.Op['=='], operator.eq)
funcName = "MatEqual"
else:
assert False
typ_3 = node.type
expr_3 = self.getTempVar()
cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
comment = IR.Comment(str(node.metadata))
if not(Util.Config.disableTruncOpti):
expr1_sf = self.scaleFacMapping[expr_1.idf]
expr2_sf = self.scaleFacMapping[expr_2.idf]
scaleUpFactor = -1
if (expr1_sf > expr2_sf):
exprToScale = expr_2
typeOfExprToScale = node.expr2.type
scaleUpFactor = expr1_sf - expr2_sf
self.scaleFacMapping[expr_2.idf] = expr1_sf
elif (expr2_sf > expr1_sf):
exprToScale = expr_1
typeOfExprToScale = node.expr1.type
scaleUpFactor = expr2_sf - expr1_sf
self.scaleFacMapping[expr_1.idf] = expr2_sf
if scaleUpFactor!=-1:
comm = IR.Comment('Scale up of args needed was found while doing OptimizeTruncations.')
argsDict = OrderedDict()
curFuncName = "ScaleUp"
if not(Type.isInt(typeOfExprToScale)):
outputShape = typeOfExprToScale.shape
for ii,curDimSize in enumerate(outputShape):
argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
curFuncName += str(len(outputShape))
argsDict[exprToScale] = "exprToScale, arg#{0}".format(2 if (expr1_sf>expr2_sf) else 1)
argsDict[IR.Int(scaleUpFactor, 32)] = "ScaleUpFactor"
funcCall = IR.FuncCall(curFuncName, argsDict)
curProg = IR.Prog([comm,funcCall])
prog_1 = IRUtil.prog_merge(curProg, prog_1)
self.scaleFacMapping[expr_3.idf] = self.scaleFacMapping[expr_1.idf]
if Type.isInt(typ_3):
decl = IR.Decl(expr_3.idf, typ_3, typ_3.bitlen, typ_3.isSecret)
assign = IR.Assn(expr_3, IR.IntBop(expr_1, op_ir, expr_2))
prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, decl, assign]))
else:
## TODO
if (node.type.dim != node.expr1.type.dim):
# This needs broadcast of expr1
assert False # For now this shouldn't occur
if (node.type.dim != node.expr2.type.dim):
# This needs broadcast of expr2
funcName += 'BroadCast'
outputShape = typ_3.shape
argsDict = OrderedDict()
inp1_shape = node.expr1.type.shape
inp2_shape = node.expr2.type.shape
for ii,curDimSize in enumerate(inp1_shape):
argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
for ii,curDimSize in enumerate(inp2_shape):
argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
for ii,curDimSize in enumerate(outputShape):
argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
argsDict[expr_1] = "A"
argsDict[expr_2] = "B"
argsDict[expr_3] = "C"
funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, funcCall]))
prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)
return (prog_3, expr_3)
def visitBopElemWiseOp(self, node:AST.BOp, args=None):
(prog_1, expr_1) = self.visit(node.expr1)
(prog_2, expr_2) = self.visit(node.expr2)
if (node.op == AST.Operators.ElemWiseMul):
op_ir = IR.Op.Op['.*']
funcName = "ElemWiseMul"
elif (node.op == AST.Operators.ElemWiseDiv):
op_ir = IR.Op.Op['./']
funcName = "ElemWiseDiv"
typ_3 = node.type
expr_3 = self.getTempVar()
cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
outputShape = typ_3.shape
argsDict = OrderedDict()
for ii,curDimSize in enumerate(outputShape):
argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
argsDict[expr_1] = "A"
argsDict[expr_2] = "B"
argsDict[expr_3] = "C"
progExtraBefore = IR.Prog([])
progExtraAfter = IR.Prog([])
if (Util.Config.disableTruncOpti):
progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF)
else:
expr1_sf = self.scaleFacMapping[expr_1.idf]
expr2_sf = self.scaleFacMapping[expr_2.idf]
if (expr1_sf > self.scaleFac):
progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac)
self.scaleFacMapping[expr_1.idf] = self.scaleFac
if (expr2_sf > self.scaleFac):
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac))
self.scaleFacMapping[expr_2.idf] = self.scaleFac
self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
return (prog_3, expr_3)
def visitBopMul(self, node:AST.BOp, args=None):
typ_1 = node.expr1.type
typ_2 = node.expr2.type
typ_3 = node.type
if (Type.isInt(typ_3)): return self.visitBopMulInt(node)
elif (typ_1.dim == 0 or Type.isInt(typ_1) or typ_2.dim == 0 or Type.isInt(typ_2)): return self.visitBopMulScalar1DTensor(node)
else: return self.visitBopMul2DTensor(node)
def visitBopMulInt(self, node:AST.BOp, args=None):
(prog_1, expr_1) = self.visit(node.expr1)
(prog_2, expr_2) = self.visit(node.expr2)
expr_3 = self.getTempVar()
comment = IR.Comment(str(node.metadata))
bitlen = node.expr.bitlen
decl = IR.Decl(expr_3.idf, node.type, node.type.bitlen, node.type.isSecret)
assign = IR.Assn(expr_3, IRUtil.mul(expr_1, expr_2))
prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, decl, assign]))
progExtraBefore = IR.Prog([])
progExtraAfter = IR.Prog([])
if (Util.Config.disableTruncOpti):
progExtraAfter = self.addTruncateFunctionCall(node, "MulInt", expr_3, Util.Config.consSF)
else:
expr1_sf = self.scaleFacMapping[expr_1.idf]
expr2_sf = self.scaleFacMapping[expr_2.idf]
if (expr1_sf > self.scaleFac):
progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MulInt", expr_1, expr1_sf-self.scaleFac)
self.scaleFacMapping[expr_1.idf] = self.scaleFac
if (expr2_sf > self.scaleFac):
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MulInt", expr_2, expr2_sf-self.scaleFac))
self.scaleFacMapping[expr_2.idf] = self.scaleFac
self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
prog_3 = IRUtil.prog_merge(progExtraBefore, prog_3, progExtraAfter)
return (prog_3, expr_3)
def visitBopMulScalar1DTensor(self, node:AST.BOp, args=None):
(prog_1, expr_1) = self.visit(node.expr1)
(prog_2, expr_2) = self.visit(node.expr2)
typ_1 = node.expr1.type
typ_2 = node.expr2.type
typ_3 = node.type
isIntMult = False
if typ_1.dim == 0 or Type.isInt(typ_1):
a, b = expr_1, expr_2
outputShape = typ_2.shape
isIntMult = (Type.isInt(typ_1))
else:
a, b = expr_2, expr_1
outputShape = typ_1.shape
isIntMult = (Type.isInt(typ_2))
# decl fresh vars
expr_3 = self.getTempVar()
cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
funcCallArgsDict = OrderedDict()
for ii,curDimSize in enumerate(outputShape):
funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
funcCallArgsDict[a] = "A"
funcCallArgsDict[b] = "B"
funcCallArgsDict[expr_3] = "C"
progExtraBefore = IR.Prog([])
progExtraAfter = IR.Prog([])
if (Util.Config.disableTruncOpti):
progExtraAfter = self.addTruncateFunctionCall(node, "ScalarMul", expr_3, Util.Config.consSF)
else:
expr1_sf = self.scaleFacMapping[expr_1.idf]
expr2_sf = self.scaleFacMapping[expr_2.idf]
if (expr1_sf > self.scaleFac):
progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ScalarMul", expr_1, expr1_sf-self.scaleFac)
self.scaleFacMapping[expr_1.idf] = self.scaleFac
if (expr2_sf > self.scaleFac):
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ScalarMul", expr_2, expr2_sf-self.scaleFac))
self.scaleFacMapping[expr_2.idf] = self.scaleFac
self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
funcCall = IR.FuncCall('ScalarMul' + self.varNameDelim + str(len(outputShape)), funcCallArgsDict)
prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
return (prog_3, expr_3)
def visitBopMul2DTensor(self, node:AST.BOp, args=None):
(prog_1, expr_1) = self.visit(node.expr1)
(prog_2, expr_2) = self.visit(node.expr2)
# decl fresh vars
expr_3 = self.getTempVar()
typ_1 = node.expr1.type
typ_2 = node.expr2.type
typ_3 = node.type
[I, J] = typ_1.shape
[J, K] = typ_2.shape
typ_mul = Type.Tensor([J])
shrT = Util.Config.consSF
cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
funcCallArgsDict = OrderedDict()
funcCallArgsDict[IR.Int(I, 32)] = "I"
funcCallArgsDict[IR.Int(J, 32)] = "J"
funcCallArgsDict[IR.Int(K, 32)] = "K"
funcCallArgsDict[expr_1] = "A"
funcCallArgsDict[expr_2] = "B"
funcCallArgsDict[expr_3] = "C"
#Add an arg as to which arg out of A or B is the model
# This is ok, since Athos is right now tailored for neural network inference
# and in inference, in every linear layer, either of A or B will be the model.
# This is required because for some backends, knowing which of A or B is the model
# can make a difference in their performance.
modelIsA = True
if ((expr_1.idf in self.inputByPartyMapping) and (self.inputByPartyMapping[expr_1.idf]==0)):
modelIsA = True
elif ((expr_2.idf in self.inputByPartyMapping) and (self.inputByPartyMapping[expr_2.idf]==0)):
modelIsA = False
else:
print("Expecting one of A or B to be the model, input by the server.")
assert(False)
funcCallArgsDict[IR.Bool(modelIsA)] = "modelIsA"
progExtraBefore = IR.Prog([])
progExtraAfter = IR.Prog([])
if (Util.Config.disableTruncOpti):
progExtraAfter = self.addTruncateFunctionCall(node, "MatMul2D", expr_3, Util.Config.consSF)
else:
expr1_sf = self.scaleFacMapping[expr_1.idf]
expr2_sf = self.scaleFacMapping[expr_2.idf]
if (expr1_sf > self.scaleFac):
progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MatMul2D", expr_1, expr1_sf-self.scaleFac)
self.scaleFacMapping[expr_1.idf] = self.scaleFac
if (expr2_sf > self.scaleFac):
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MatMul2D", expr_2, expr2_sf-self.scaleFac))
self.scaleFacMapping[expr_2.idf] = self.scaleFac
self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
funcCall = IR.FuncCall("MatMul2D", funcCallArgsDict)
comment = IR.Comment(str(node.metadata))
prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([comment, cmd0, funcCall]))
prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
return (prog_3, expr_3)
def visitBopConv(self, node:AST.BOp, args=None):
(prog1, expr1) = self.visit(node.expr1)
(prog2, expr2) = self.visit(node.expr2)
convDim = 2
if (AST.PaddingKeysDict.ConvDim in node.options):
convDim = node.options[AST.PaddingKeysDict.ConvDim]
if convDim == 2:
[N, H, W, CI] = node.expr1.type.shape
[FH, FW, CI1, CO] = node.expr2.type.shape
elif convDim == 3:
[N, D, H, W, CI] = node.expr1.type.shape
[FD, FH, FW, CI1, CO] = node.expr2.type.shape
else:
assert(False)
returnExpr = self.getTempVar()
comment = IR.Comment(expr1.idf + ' # ' + expr2.idf + ', convDim = ' + str(convDim))
funcCallArgsDict = OrderedDict()
funcCallArgsDict[IR.Int(N, 32)] = "N"
if convDim == 3:
funcCallArgsDict[IR.Int(D, 32)] = "D"
funcCallArgsDict[IR.Int(H, 32)] = "H"
funcCallArgsDict[IR.Int(W, 32)] = "W"
funcCallArgsDict[IR.Int(CI, 32)] = "CI"
if convDim == 3:
funcCallArgsDict[IR.Int(FD, 32)] = "FD"
funcCallArgsDict[IR.Int(FH, 32)] = "FH"
funcCallArgsDict[IR.Int(FW, 32)] = "FW"
funcCallArgsDict[IR.Int(CO, 32)] = "CO"
if convDim == 3:
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadDLeft], 32)] = "zPadDLeft"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadDRight], 32)] = "zPadDRight"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32)] = "zPadHLeft"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32)] = "zPadHRight"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32)] = "zPadWLeft"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32)] = "zPadWRight"
if convDim == 3:
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideD], 32)] = "strideD"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideH], 32)] = "strideH"
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideW], 32)] = "strideW"
isGroupConv = False
if AST.PaddingKeysDict.group in node.options.keys():
funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.group], 32)] = "G"
isGroupConv = True
funcCallArgsDict[expr1] = "input"
funcCallArgsDict[expr2] = "filter"
if convDim == 3:
funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF"
funcCallArgsDict[returnExpr] = "output"
if convDim == 2:
funcCallName = "Conv2D"
else:
funcCallName = "Conv3D"
if isGroupConv:
funcCallName += "Group"
funcCallName += "Wrapper"
funcCall = IR.FuncCall(funcCallName, funcCallArgsDict)
progConv = IR.Prog([comment, funcCall])
progExtraBefore = IR.Prog([])
progExtraAfter = IR.Prog([])
if (Util.Config.disableTruncOpti):
progExtraAfter = self.addTruncateFunctionCall(node, "Conv", returnExpr, Util.Config.consSF)
else:
expr1_sf = self.scaleFacMapping[expr1.idf]
expr2_sf = self.scaleFacMapping[expr2.idf]
if (expr1_sf > self.scaleFac):
progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr1, expr1_sf-self.scaleFac)
self.scaleFacMapping[expr1.idf] = self.scaleFac
if (expr2_sf > self.scaleFac):
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr2, expr2_sf-self.scaleFac))
self.scaleFacMapping[expr_2.idf] = self.scaleFac
self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac
returnProg = IRUtil.prog_merge(prog1, prog2, progExtraBefore, progConv)
returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg, progExtraAfter)
return (returnProg, returnExpr)
def visitBopConvTranspose(self, node:AST.BOp, args=None):
(prog1, expr1) = self.visit(node.expr1)
(prog2, expr2) = self.visit(node.expr2)
convDim = 2
if (AST.PaddingKeysDict.ConvDim in node.options):
convDim = node.options[AST.PaddingKeysDict.ConvDim]
if convDim==2:
[N, H_prime, W_prime, CI1] = node.expr1.type.shape
[FH, FW, CO, CI] = node.expr2.type.shape
elif convDim==3:
[N, D_prime, H_prime, W_prime, CI1] = node.expr1.type.shape
[FD, FH, FW, CO, CI] = node.expr2.type.shape
else:
assert(False)
assert(CI1 == CI)
H = node.options[AST.PaddingKeysDict.outputImgH] #outputH
W = node.options[AST.PaddingKeysDict.outputImgW] #outputW
pad_h_total = node.options[AST.PaddingKeysDict.zPadHLeft] + node.options[AST.PaddingKeysDict.zPadHRight]
pad_w_total = node.options[AST.PaddingKeysDict.zPadWLeft] + node.options[AST.PaddingKeysDict.zPadWRight]
strideH = node.options[AST.PaddingKeysDict.strideH]
strideW = node.options[AST.PaddingKeysDict.strideW]
[pad_h_tr_total, stride_h_tr, h_prime_tilde] = AST.Operators.findConvTransposePadding(H, H_prime, FH, pad_h_total, strideH)
[pad_w_tr_total, stride_w_tr, w_prime_tilde] = AST.Operators.findConvTransposePadding(W, W_prime, FW, pad_w_total, strideW)
[pad_h_tr_left, pad_h_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_h_tr_total)
[pad_w_tr_left, pad_w_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_w_tr_total)
assert(AST.Operators.findConvOutputImgSize(h_prime_tilde, pad_h_tr_total, FH, stride_h_tr) == H)
assert(AST.Operators.findConvOutputImgSize(w_prime_tilde, pad_w_tr_total, FW, stride_w_tr) == W)
if convDim == 3:
D = node.options[AST.PaddingKeysDict.outputImgD] #outputD
pad_d_total = node.options[AST.PaddingKeysDict.zPadDLeft] + node.options[AST.PaddingKeysDict.zPadDRight]
strideD = node.options[AST.PaddingKeysDict.strideD]
[pad_d_tr_total, stride_d_tr, d_prime_tilde] = AST.Operators.findConvTransposePadding(D, D_prime, FD, pad_d_total, strideD)
[pad_d_tr_left, pad_d_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_d_tr_total)
assert(AST.Operators.findConvOutputImgSize(d_prime_tilde, pad_d_tr_total, FD, stride_d_tr) == D)
returnExpr = self.getTempVar()
comment = IR.Comment(expr1.idf + ' #T ' + expr2.idf + ', convDim = ' + str(convDim))
funcCallArgsDict = OrderedDict()
funcCallArgsDict[IR.Int(N, 32)] = "N"
if convDim==3:
funcCallArgsDict[IR.Int(D_prime, 32)] = "D_prime"
funcCallArgsDict[IR.Int(H_prime, 32)] = "H_prime"
funcCallArgsDict[IR.Int(W_prime, 32)] = "W_prime"
funcCallArgsDict[IR.Int(CI, 32)] = "CI"
if convDim==3:
funcCallArgsDict[IR.Int(FD, 32)] = "FD"
funcCallArgsDict[IR.Int(FH, 32)] = "FH"
funcCallArgsDict[IR.Int(FW, 32)] = "FW"
funcCallArgsDict[IR.Int(CO, 32)] = "CO"
if convDim==3:
funcCallArgsDict[IR.Int(D, 32)] = "D"
funcCallArgsDict[IR.Int(H, 32)] = "H"
funcCallArgsDict[IR.Int(W, 32)] = "W"
if convDim==3:
funcCallArgsDict[IR.Int(pad_d_tr_left, 32)] = "pad_d_tr_left"
funcCallArgsDict[IR.Int(pad_d_tr_right, 32)] = "pad_d_tr_right"
funcCallArgsDict[IR.Int(pad_h_tr_left, 32)] = "pad_h_tr_left"
funcCallArgsDict[IR.Int(pad_h_tr_right, 32)] = "pad_h_tr_right"
funcCallArgsDict[IR.Int(pad_w_tr_left, 32)] = "pad_w_tr_left"
funcCallArgsDict[IR.Int(pad_w_tr_right, 32)] = "pad_w_tr_right"
if convDim==3:
funcCallArgsDict[IR.Int(strideD, 32)] = "strideD"
funcCallArgsDict[IR.Int(strideH, 32)] = "strideH"
funcCallArgsDict[IR.Int(strideW, 32)] = "strideW"
funcCallArgsDict[expr1] = "input"
funcCallArgsDict[expr2] = "filter"
if convDim == 3:
funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF"
funcCallArgsDict[returnExpr] = "output"
if convDim == 2:
funcCallName = "ConvTranspose2D"
else:
funcCallName = "ConvTranspose3D"
funcCallName += "Wrapper"
funcCall = IR.FuncCall(funcCallName, funcCallArgsDict)
progConv = IR.Prog([comment, funcCall])
progExtraBefore = IR.Prog([])
progExtraAfter = IR.Prog([])
if (Util.Config.disableTruncOpti):
progExtraAfter = self.addTruncateFunctionCall(node, "ConvTranspose", returnExpr, self.scaleFac)
else:
expr1_sf = self.scaleFacMapping[expr1.idf]
expr2_sf = self.scaleFacMapping[expr2.idf]
if (expr1_sf > self.scaleFac):
progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr1, expr1_sf-self.scaleFac)
self.scaleFacMapping[expr1.idf] = self.scaleFac
if (expr2_sf > self.scaleFac):
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr2, expr2_sf-self.scaleFac))
self.scaleFacMapping[expr2.idf] = self.scaleFac
self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac
returnProg = IRUtil.prog_merge(prog1, prog2, progExtraBefore, progConv)
returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg, progExtraAfter)
return (returnProg, returnExpr)
def visitFunc(self, node:AST.Func, args=None):
op = node.op
assert(op in [AST.Operators.Floor, AST.Operators.Shape, AST.Operators.RELU, AST.Operators.ClearMemSecret, AST.Operators.ClearMemPublic])
return self.visitFloorLike(node)
def visitFloorLike(self, node:AST.Func, args=None):
(prog1, expr1) = self.visit(node.expr)
tmpExpr = self.getTempVar()
if node.op == AST.Operators.Floor:
funcName = "Floor"
elif node.op == AST.Operators.Shape:
funcName = "Shape"
elif node.op == AST.Operators.RELU:
funcName = "Relu"
elif node.op == AST.Operators.ClearMemSecret:
funcName = "ClearMemSecret"
elif node.op == AST.Operators.ClearMemPublic:
funcName = "ClearMemPublic"
else:
assert False
argsList = OrderedDict()
inputType = node.expr.type
if Type.isTensor(inputType):
for ii, curDim in enumerate(inputType.shape):
argsList[IR.Int(curDim, 32)] = "inShape_" + str(ii)
argsList[expr1] = "inArr"
if Type.isTensor(node.type):
argsList[tmpExpr] = "outArr"
if node.op == AST.Operators.Floor:
argsList[IR.Int(Util.Config.consSF,32)] = "curScale"
progExtra = IR.Prog([])
if (Util.Config.disableTruncOpti):
if node.op == AST.Operators.RELU:
argsList[IR.Int(Util.Config.consSF,32)] = "consSF"
argsList[IR.Bool(False)] = "doTruncation"
else:
final_sf = self.scaleFacMapping[expr1.idf]
if node.op == AST.Operators.RELU:
argsList[IR.Int(final_sf - self.scaleFac,32)] = "consSF"
if (final_sf > self.scaleFac):
#If it can't tolerate one more mult operation, then scale down here
final_sf = self.scaleFac
argsList[IR.Bool(True)] = "doTruncation"
else:
argsList[IR.Bool(False)] = "doTruncation"
self.scaleFacMapping[tmpExpr.idf] = final_sf
comment = IR.Comment(str(node.metadata))
funcNameSuffix = ""
if Type.isTensor(inputType):
funcNameSuffix = str(len(inputType.shape))
progFinal = IRUtil.prog_merge(prog1 , IR.Prog([comment, IR.FuncCall(funcName + self.varNameDelim + funcNameSuffix, argsList)]))
if Type.isTensor(node.type):
progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), progFinal)
progFinal = IRUtil.prog_merge(progFinal, progExtra)
return (progFinal, tmpExpr)
def visitLet(self, node:AST.Let, args=None):
(prog_1, expr_1) = self.visit(node.decl)
typ_1 = node.decl.type
idf = node.name.name
if (not(Type.isInt(typ_1))):
self.name_mapping[idf] = expr_1.idf
if (not(Util.Config.disableTruncOpti)):
self.scaleFacMapping[idf] = self.scaleFacMapping[expr_1.idf]
if (expr_1.idf in self.inputByPartyMapping):
self.inputByPartyMapping[idf] = self.inputByPartyMapping[expr_1.idf]
del self.inputByPartyMapping[expr_1.idf]
(prog_2, expr_2) = self.visit(node.expr)
prog_2 = prog_2.subst(idf, expr_1)
expr_2 = expr_2.subst(idf, expr_1)
prog_3 = IRUtil.prog_merge(prog_1, prog_2)
return (prog_3, expr_2)
def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None):
progList = []
exprList = []
for ii, curArg in enumerate(node.argsList):
(progN, exprN) = self.visit(curArg)
progList.append(progN)
exprList.append(exprN)
returnExpr = self.getTempVar()
funcName = node.funcName
funcName += self.varNameDelim + str(len(node.outputShape))
for ii, curArg in enumerate(node.argsList):
if Type.isTensor(curArg.type):
curShape = curArg.type.shape
# If len(shape) == 0 : that means its a float - no need to qualify
# the function name with 0 in that case, since its essentially
# become an int.
if (len(curShape) > 0):
funcName += self.varNameDelim + str(len(curShape))
### TODO : right now if random strings like int are passed, its being set as datatype int -- int datatype in
# unintrepreted func call is being used in a hacky way right now
# Policy :
# First output tensor sizes are inserted in args.
# Then for each input tensor, its shape is inserted in args, followed by the input tensor itself.
# If the current input tensor has the same shape as any of the previous tensors, then its shape is not inserted.
funcArgsList = OrderedDict()
if not(Util.Config.disableTruncOpti):
#TODO -- remove CreateTensor from uninterp function calls
for ii, curArg in enumerate(node.argsList):
curExpr = exprList[ii]
curScale = self.scaleFacMapping[curExpr.idf]
curType = curArg.type
if (not(Type.isInt(curType))) and (curScale > self.scaleFac) and (curType.isSecret):
curProg = self.addTruncateFunctionCall(curArg, "UninterpFuncCall", curExpr, curScale - self.scaleFac)
progList.insert(0,curProg)
self.scaleFacMapping[curExpr.idf] = self.scaleFac
self.scaleFacMapping[returnExpr.idf] = self.scaleFac
tensorShapesFound = {}
outputShape = node.type.shape
for ii, curDim in enumerate(outputShape):
funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
tensorShapesFound[tuple(outputShape)] = True
for ii in range(0, len(node.argsList)):
if node.outputDiffInpDims < 2 and Type.isTensor(node.argsList[ii].type):
curInpShape = node.argsList[ii].type.shape
if ((node.outputDiffInpDims == 1) or (node.outputDiffInpDims == 0 and tuple(curInpShape) not in tensorShapesFound)):
for jj, curDim in enumerate(curInpShape):
funcArgsList[IR.Int(curDim, 32)] = "Input_" + str(ii) + self.varNameDelim + str(jj)
tensorShapesFound[tuple(curInpShape)] = True
funcArgsList[exprList[ii]] = "inpExpr_" + str(ii)
funcArgsList[returnExpr] = "output"
comment = IR.Comment(str(node.metadata))
progFinal = progList[0]
if len(progList) > 1:
for ii in range(1, len(progList)):
progFinal = IRUtil.prog_merge(progFinal, progList[ii])
progFinal = IRUtil.prog_merge(progFinal, IR.Prog([comment, IR.FuncCall(funcName, funcArgsList)]))
progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf,
node.type,
isSecret=False if node.isSecret is False else "secret")]),
progFinal)
return (progFinal, returnExpr)
def visitArgMax(self, node:AST.ArgMax, args=None):
(prog_1, expr1) = self.visit(node.expr)
(prog_2, expr2) = self.visit(node.dim)
tmpExpr = self.getTempVar()
outputShape = node.type.shape
funcArgsList = OrderedDict()
outputShape = node.type.shape
for ii, curDim in enumerate(outputShape):
funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
for ii, curDim in enumerate(node.inShape):
funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
funcArgsList[expr1] = "inArr"
funcArgsList[expr2] = "dim"
funcArgsList[tmpExpr] = "outArr"
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[tmpExpr.idf] = 0 #TODO -- is this the right thing to do?
funcCall = IR.FuncCall("ArgMax" + self.varNameDelim + str(len(outputShape)), funcArgsList)
comment = IR.Comment(str(node.metadata))
prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))
prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), prog_3)
return (prog_3, tmpExpr)
def visitInput(self, node:AST.Input, args=None):
returnExpr = self.getTempVar()
returnExpr.inputVar = True
comment = IR.Comment(str(node.metadata))
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[returnExpr.idf] = self.scaleFac
self.inputByPartyMapping[returnExpr.idf] = node.inputByParty
return (IR.Prog([comment, IR.Input(returnExpr, node.shape, node.dataType, node.isSecret, node.inputByParty)]), returnExpr)
def visitReduce(self, node:AST.Reduce, args=None):
(prog_1, expr1) = self.visit(node.expr)
(prog_2, expr2) = self.visit(node.dim)
returnExpr = self.getTempVar()
assert(node.op in [AST.Operators.ADD, AST.Operators.Mean])
if (node.op == AST.Operators.ADD):
funcName = "ReduceSum"
elif (node.op == AST.Operators.Mean):
funcName = "ReduceMean"
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[returnExpr.idf] = self.scaleFacMapping[expr1.idf]
funcArgsList = OrderedDict()
outputShape = node.type.shape
for ii, curDim in enumerate(outputShape):
funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
inputShape = node.expr.type.shape
for ii, curDim in enumerate(inputShape):
funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)
funcArgsList[expr1] = "inputArr"
funcArgsList[expr2] = "dimension"
funcArgsList[returnExpr] = "outArr"
funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)) + self.varNameDelim + str(len(inputShape)), funcArgsList)
comment = IR.Comment(str(node.metadata))
prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))
prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), prog_3)
return (prog_3, returnExpr)
def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None):
(prog1, expr1) = self.visit(node.expr)
(prog2, expr2) = self.visit(node.multExpr)
(prog3, expr3) = self.visit(node.addExpr)
returnExpr = self.getTempVar()
funcArgsList = OrderedDict()
for ii, elem in enumerate(node.type.shape):
funcArgsList[IR.Int(elem, 32)] = "elem_"+str(ii)
funcArgsList[expr1] = "expr"
funcArgsList[expr2] = "multExpr"
funcArgsList[expr3] = "addExpr"
progExtraBefore = IR.Prog([])
multExprScaleDownSf = self.scaleFac
addExprScaleUpSf = 0
if not(Util.Config.disableTruncOpti):
#TruncOpti is on
multExprScaleDownSf = 0
addExprScaleUpSf = 0
expr_sf = self.scaleFacMapping[expr1.idf]
multExpr_sf = self.scaleFacMapping[expr2.idf]
addExpr_sf = self.scaleFacMapping[expr3.idf]
if (expr_sf > self.scaleFac):
#Scale down needed
progExtraBefore = self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac)
self.scaleFacMapping[expr1.idf] = self.scaleFac
if (multExpr_sf > self.scaleFac):
#Scale down needed
progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.multExpr, "FusedBatchNorm", expr2, multExpr_sf - self.scaleFac))
self.scaleFacMapping[expr2.idf] = self.scaleFac
final_sf = 2*self.scaleFac
assert(final_sf >= addExpr_sf)
if (final_sf > addExpr_sf):
addExprScaleUpSf = final_sf - addExpr_sf
self.scaleFacMapping[expr3.idf] += addExprScaleUpSf
self.scaleFacMapping[returnExpr.idf] = final_sf
funcArgsList[IR.Int(multExprScaleDownSf, 32)] = "multExprScaleDownSf"
funcArgsList[IR.Int(addExprScaleUpSf, 32)] = "addExprScaleUpSf"
funcArgsList[returnExpr] = "returnExpr"
funcCallIR = IR.FuncCall("FusedBatchNorm" + self.varNameDelim
+ str(len(node.type.shape)) + self.varNameDelim #one for output
+ str(len(node.type.shape)) + self.varNameDelim #one for input
+ str(len(node.multExpr.type.shape)) + self.varNameDelim
+ str(len(node.addExpr.type.shape)),
funcArgsList)
comment = IR.Comment(str(node.metadata))
returnProg = IRUtil.prog_merge(prog1, prog2, prog3, progExtraBefore, IR.Prog([comment, funcCallIR]))
returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg)
return (returnProg, returnExpr)