Skip to content
Snippets Groups Projects
Commit f416af1e authored by Bhatu's avatar Bhatu
Browse files

Support for Split operation

We support splitting of a tensor along an axis into n pieces, where n
has to be a constant.
Eg:
  Split(Tensor of shape(5,30), splits=3, axis=1)
  returns 3 tensors of shape(5,10) each.

Currently we do not suport splitting into tensors of specified shape
(num_or_size_splits) though that functionality will be added later.

We also do not support splitting into n pieces where n is a runtime
value because we do not support run-time code generation yet.

This also adds support in the frontend for an op to return multiple
values.
parent 3844d8cf
No related branches found
No related tags found
No related merge requests found
......@@ -196,6 +196,19 @@ class Transpose(ASTNode):
self.expr = expr
self.perm = perm
# expr : ASTNode, perm : list of ints
class Slice(ASTNode):
def __init__(self, expr: ASTNode, subscriptRanges: list = None):
if assertInputTypes:
assert isinstance(expr, ID)
if subscriptRanges:
for elem in subscriptRanges:
assert isinstance(elem[0], int)
assert isinstance(elem[1], int)
super().__init__()
self.expr = expr
self.subscriptRanges = subscriptRanges
# expr : ASTNode, shape : list of int, order : int : optional
class Reshape(ASTNode):
def __init__(self, expr: ASTNode, shape: list, order: list):
......@@ -363,7 +376,7 @@ class Reduce(ASTNode):
# NOTE: Though datatype is being passed to this function, the output code eventually only has
# int in the apt bitlen for which the whole compilation is done
# Also, take note of the last parameter - "inputByParty". This can be used to set the party which
# which will do the input for this variable. Defaults to 0, which is interpretted as SERVER by the codegen.
# which will do the input for this variable. Defaults to SERVER.
class Input(ASTNode):
def __init__(self, shape:list, dataType:str, isSecret=True, inputByParty=Party.SERVER):
if assertInputTypes:
......
......@@ -42,6 +42,9 @@ class ASTVisitor:
def visitTranspose(self, node:AST.Transpose, args=None):
self.visit(node.expr, args)
def visitSlice(self, node:AST.Slice, args=None):
self.visit(node.expr, args)
def visitReshape(self, node:AST.Reshape, args=None):
self.visit(node.expr, args)
......@@ -97,6 +100,8 @@ class ASTVisitor:
return self.visitDecl(node, args)
elif isinstance(node, AST.Transpose):
return self.visitTranspose(node, args)
elif isinstance(node, AST.Slice):
return self.visitSlice(node, args)
elif isinstance(node, AST.Reshape):
return self.visitReshape(node, args)
elif isinstance(node, AST.Pool):
......
......@@ -42,6 +42,10 @@ class MtdAST(ASTVisitor):
node.metadata.update(mtd)
self.visit(node.expr, mtd)
def visitSlice(self, node:AST.Slice, mtd:dict):
node.metadata.update(mtd)
self.visit(node.expr, mtd)
def visitReshape(self, node:AST.Reshape, mtd:dict):
node.metadata.update(mtd)
self.visit(node.expr, mtd)
......
......@@ -51,6 +51,12 @@ class PrintAST(ASTVisitor):
self.visit(node.expr)
print("^Transpose", end=' ')
def visitSlice(self, node:AST.Transpose, args=None):
node.expr.depth = node.depth + 1
print(indent * node.depth, end=' ')
self.visit(node.expr)
print("extract slice", end=' ')
def visitReshape(self, node:AST.Reshape, args=None):
node.expr.depth = node.depth + 1
print(indent * node.depth, "reshape", end=' ')
......
......@@ -231,6 +231,37 @@ class IRBuilderCSF(IRBuilderAST):
return (final_prog, out_arr)
def visitSlice(self, node:AST.Slice, args=None):
(inp_prog, inp_arr) = self.visit(node.expr)
inp_type = node.expr.type
out_type = node.type
out_iters = self.getTempIterators(out_type.dim)
inp_iters = []
subscriptRanges = node.subscriptRanges
for idx,subrange in enumerate(subscriptRanges):
start = subrange[0]
inp_iters.append(IRUtil.add(out_iters[idx], IR.Int(start)))
out_arr = self.getTempVar()
out_arr_expr = IRUtil.addIndex(out_arr, out_iters)
inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters)
assign_expr = IR.Assn(out_arr_expr, inp_arr_expr)
loop = IRUtil.loop(out_type.shape, out_iters, [assign_expr])
# Finalize
comment1 = IR.Comment(str(node.metadata))
comment2 = IR.Comment("slice(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])")
slice_prog = IR.Prog([comment1, comment2] + loop)
final_prog = IRUtil.prog_merge(inp_prog, slice_prog)
for var in out_iters:
final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog)
final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog)
if not(Util.Config.disableTruncOpti):
self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf]
return (final_prog, out_arr)
def visitReshape(self, node:AST.Reshape, args=None):
(prog_1, expr_1) = self.visit(node.expr)
......
......@@ -52,6 +52,11 @@ class LivenessAnalysis(ASTVisitor):
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitSlice(self, node:AST.Slice, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitReshape(self, node:AST.Reshape, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
......
......@@ -207,6 +207,26 @@ class InferType(ASTVisitor):
node.type = Tensor(new_shape, exprType.bitlen, exprType.isSecret, exprType.taint)
return node.type
def visitSlice(self, node:AST.Slice, args=None):
node.expr.gamma = dict(node.gamma)
exprType = self.visit(node.expr)
assert isTensor(exprType)
subscriptRanges = node.subscriptRanges
shape = []
for i in subscriptRanges:
start = i[0]
end = i[1]
size = end - start + 1
shape.append(size)
assert(len(shape) == len(exprType.shape))
for i in range(0,len(shape)):
assert(shape[i] <= exprType.shape[i])
node.type = Tensor(shape, exprType.bitlen, exprType.isSecret, exprType.taint)
return node.type
def visitReshape(self, node:AST.Reshape, args=None):
node.expr.gamma = dict(node.gamma)
exprType = self.visit(node.expr)
......
......@@ -37,8 +37,8 @@ def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDic
curNodeOp = curNode.getOp()
ast = None
func = getattr(TFNodesAST, curNodeOp)
(assignedVarAST, curAST) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)
return (assignedVarAST, curAST)
(assignedVarAST, curASTs) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)
return (assignedVarAST, curASTs)
#Takes the graph DS and outputs IR in SeeDot for the same
def generateIRCode(graph, extraInfoDict):
......@@ -51,29 +51,29 @@ def generateIRCode(graph, extraInfoDict):
for curNode in graph.getAllNodesRef():
for curInp in curNode.getInputsRef():
assert(curInp in dictNodeNameToOutVarStr) #Consequence of topological sorting of the TF graph
(assignedVarAST, curAst) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict)
mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(),
AST.ASTNode.mtdKeyTFNodeName : curNode.getName()}
if (curAst is None):
dictNodeNameToOutVarStr[curNode.getName()] = None
continue
curOutVarStr = outVarPrefix + str(outVarCt)
curOutVarAstNode = (assignedVarAST if assignedVarAST else AST.ID(curOutVarStr))
if program:
assert(type(innerMostLetASTNode) is AST.Let)
newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode)
mtdAST.visit(newNode, mtdForCurAST)
innerMostLetASTNode.expr = newNode
innerMostLetASTNode = newNode
else:
innerMostLetASTNode = AST.Let(AST.ID(curOutVarStr), curAst, curOutVarAstNode)
mtdAST.visit(innerMostLetASTNode, mtdForCurAST)
innerMostLetASTNode.depth = 0
program = innerMostLetASTNode
dictNodeNameToOutVarStr[curNode.getName()] = curOutVarStr
outVarCt += 1
(assignedVarAST, curAsts) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict)
for outputName, curAst in curAsts.items():
mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(),
AST.ASTNode.mtdKeyTFNodeName : outputName}
if (curAst is None):
dictNodeNameToOutVarStr[outputName] = None
continue
curOutVarStr = outVarPrefix + str(outVarCt)
curOutVarAstNode = (assignedVarAST if assignedVarAST else AST.ID(curOutVarStr))
if program:
assert(type(innerMostLetASTNode) is AST.Let)
newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode)
mtdAST.visit(newNode, mtdForCurAST)
innerMostLetASTNode.expr = newNode
innerMostLetASTNode = newNode
else:
innerMostLetASTNode = AST.Let(AST.ID(curOutVarStr), curAst, curOutVarAstNode)
mtdAST.visit(innerMostLetASTNode, mtdForCurAST)
innerMostLetASTNode.depth = 0
program = innerMostLetASTNode
dictNodeNameToOutVarStr[outputName] = curOutVarStr
outVarCt += 1
return (program, dictNodeNameToOutVarStr)
def readSizeInfo(fileName):
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment