Skip to content
Snippets Groups Projects
Commit 07901e61 authored by Bhatu's avatar Bhatu
Browse files

Add a new garbage collector pass.

For identity like ops (a=b), we sometimes run into use-after-free and
double-free bugs.

For this snippet
    J100 = J99
    J101 = J99 + 3         <- last use of J99
    J102 = J100 * 2        <- last use of J100
before we were doing:
    J100 = J99
    J101 = J99 + 3
    free(J99)
    J102 = J100 * 2        <- use-after-free
    free(J100)             <- double-free
now we do:
    J100 = J99
    J101 = J99 + 3
    J102 = J100 * 2
    free(J100)

Algorithm:
We iterate through the program in reverse order and every time we see a
use of a variable, we insert a free after it, unless we have already
freed it before. When we check a variable has been freed, we also check
whether any of its aliases have also been freed.

For alias analysis, we maintain alias sets using disjoint sets. Whenever
we encounter an a=b statement, we simply do a union of a and b sets.

This replaces the old LivenessOpti pass.
parent c449271d
No related branches found
No related tags found
No related merge requests found
......@@ -37,7 +37,8 @@ from AST.MtdAST import MtdAST
from IR.IRBuilderCSF import IRBuilderCSF
from Codegen.EzPC import EzPC as EzPCCodegen
import Optimizations.ReluMaxpoolOpti as ReluMaxpoolOpti
import Optimizations.LivenessOpti as LivenessOpti
import Optimizations.GarbageCollector as GarbageCollector
from collections import OrderedDict
class Compiler:
def __init__(self, version, target, sfType, astFile, printASTBool, consSF, bitlen, outputFileName,
......@@ -117,17 +118,18 @@ class Compiler:
print("Relu-maxpool optimization done.")
if not(Util.Config.disableLivenessOpti):
print("Performing Liveness Optimization...")
print("Performing Garbage colelction...")
mtdAST = MtdAST()
LivenessOpti.LivenessAnalysis().visit(ast)
LivenessOpti.LivenessOpti().visit(ast, [mtdAST, 0, {}])
print("Liveness optimization done.")
GC = GarbageCollector.GarbageCollector(ast)
GC.run([mtdAST])
print("Garbage collection done.")
# Perform type inference and annotate nodes with type information
InferType().visit(ast)
if Util.Config.printASTBool:
PrintAST().visit(ast)
print("\n")
sys.stdout.flush()
IRUtil.init()
......
......@@ -783,7 +783,7 @@ class IRBuilderCSF(IRBuilderAST):
# and in inference, in every linear layer, either of A or B will be a model weight.
# This is required because for some backends, knowing which of A or B is a model weight
# can make a difference in their performance.
modelIsA = True
assert (self.isModel(node.expr1) or self.isModel(node.expr2)), "Expecting one of A or B to be an input by the server (model weight)."
modelIsA = True
if (not self.isModel(node.expr1)):
......
'''
Authors: Nishant Kumar.
Copyright:
Copyright (c) 2020 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import AST.AST as AST
from AST.ASTVisitor import ASTVisitor
from AST.MtdAST import MtdAST
#In the below analysis, each node saves what all unbound variables
# are used in its sub-tree. If the set is empty, nothing is saved.
# A subsequent pass then finds the variables
# wnich can be cleared.
class LivenessAnalysis(ASTVisitor):
optidictKey = "LivenessAnalysis" #This key will be used to store in optidict of the ASTNode
# list of all variables which are unbound in that sub-tree.
def visitInt(self, node:AST.Int, args):
return []
def visitFloat(self, node:AST.Float, args):
return []
def visitId(self, node:AST.ID, args):
unboundVars = [node.name]
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitDecl(self, node:AST.Decl, args):
return []
def visitTranspose(self, node:AST.Transpose, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitSlice(self, node:AST.Slice, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitReshape(self, node:AST.Reshape, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitPool(self, node:AST.Pool, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitUOp(self, node:AST.UOp, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitBOp(self, node:AST.BOp, args):
unboundVars = list(set(self.visit(node.expr1, args) + self.visit(node.expr2, args)))
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitFunc(self, node:AST.Func, args):
unboundVars = self.visit(node.expr, args)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitLet(self, node:AST.Let, args):
declVars = self.visit(node.decl, args)
exprVars = self.visit(node.expr, args)
unboundVars = list((set(declVars)|set(exprVars))-set([node.name.name]))
if isinstance(node.decl, AST.ID):
#This is of the type let J1 = J2 in J1.
# Since J1 and J2 refer to the same variable, J2 should remain bounded.
unboundVars = list(set(unboundVars) - set([node.decl.name]))
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args):
unboundVarsSet = set([])
for elem in node.argsList:
unboundVarsSet |= set(self.visit(elem, args))
unboundVars = list(unboundVarsSet)
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitArgMax(self, node:AST.ArgMax, args):
unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.dim, args)))
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitReduce(self, node:AST.Reduce, args):
unboundVars = list(set(self.visit(node.expr, args)))
node.optidict[self.optidictKey] = unboundVars
return unboundVars
def visitInput(self, node:AST.Input, args):
return []
def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args):
unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.multExpr, args) + self.visit(node.addExpr, args)))
node.optidict[self.optidictKey] = unboundVars
return unboundVars
class LivenessOpti(ASTVisitor):
def visitLet(self, node:AST.Let, args):
assert(isinstance(args, list))
assert(isinstance(args[0], MtdAST))
assert(isinstance(args[1], int))
assert(isinstance(args[2], dict)) #dict {variable name string -> isSecretVariable bool}
curUnboundVars = []
exprUnboundVars = []
if LivenessAnalysis.optidictKey in node.optidict:
curUnboundVars = node.optidict[LivenessAnalysis.optidictKey]
if LivenessAnalysis.optidictKey in node.expr.optidict:
exprUnboundVars = node.expr.optidict[LivenessAnalysis.optidictKey]
varsToDeAllocate = list(set(curUnboundVars)-set(exprUnboundVars))
origNodeExpr = node.expr
astSubTree = node.expr
mtdForNewASTNodes = {AST.ASTNode.mtdKeyTFOpName : "No-op: ClearMem",
AST.ASTNode.mtdKeyTFNodeName : ""}
for ii, curVarName in enumerate(varsToDeAllocate):
assert(curVarName in args[2])
newSubTree = AST.Let(AST.ID("cv"+str(args[1]+ii)),
AST.Func(AST.Operators.ClearMemSecret if args[2][curVarName] else AST.Operators.ClearMemPublic,
AST.ID(curVarName)),
AST.ID(""))
args[0].visit(newSubTree, mtdForNewASTNodes)
newSubTree.expr = astSubTree
node.expr = newSubTree
astSubTree = node.expr
self.visit(node.name, [args[0], args[1]+len(varsToDeAllocate), args[2]])
self.visit(node.decl, [args[0], args[1]+len(varsToDeAllocate), args[2]])
isCurrentLetDeclarationSecret = True
if hasattr(node.decl, 'isSecret'):
isCurrentLetDeclarationSecret = node.decl.isSecret
assert(type(isCurrentLetDeclarationSecret)==bool)
self.visit(origNodeExpr, [args[0], args[1]+len(varsToDeAllocate), {**args[2], **{node.name.name: isCurrentLetDeclarationSecret}}])
'''
Authors: Pratik Bhatu
Copyright:
Copyright (c) 2020 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import AST.AST as AST
import Util
from AST.ASTVisitor import ASTVisitor
from AST.MtdAST import MtdAST
class SecretFlowAnalysis(ASTVisitor):
def __init__(self):
self.idf_to_secret = {}
self.node_to_secret = {}
def isSecret(self, idf:str):
return self.idf_to_secret[idf]
def visitInt(self, node:AST.Int, args):
self.node_to_secret[node] = node.isSecret
def visitFloat(self, node:AST.Float, args):
self.node_to_secret[node] = node.isSecret
def visitInput(self, node:AST.Input, args):
self.node_to_secret[node] = node.isSecret
def visitId(self, node:AST.ID, args):
self.node_to_secret[node] = self.idf_to_secret[node.name]
def visitLet(self, node:AST.Let, args):
self.visit(node.decl, args)
self.idf_to_secret[node.name.name] = self.node_to_secret[node.decl]
self.visit(node.expr, args)
def visitDecl(self, node:AST.Decl, args):
self.node_to_secret[node] = node.isSecret
if node.valueList:
for elem in node.valueList:
self.visit(elem, args)
def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args):
self.node_to_secret[node] = node.isSecret
for elem in node.argsList:
self.visit(elem, args)
def visitTranspose(self, node:AST.Transpose, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitSlice(self, node:AST.Slice, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitReshape(self, node:AST.Reshape, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitPool(self, node:AST.Pool, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitUOp(self, node:AST.UOp, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitBOp(self, node:AST.BOp, args):
self.visit(node.expr1, args)
self.visit(node.expr2, args)
self.node_to_secret[node] = self.node_to_secret[node.expr1] | self.node_to_secret[node.expr1]
def visitFunc(self, node:AST.Func, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitArgMax(self, node:AST.ArgMax, args):
self.visit(node.expr, args)
self.visit(node.dim, args)
self.node_to_secret[node] = self.node_to_secret[node.expr] | self.node_to_secret[node.dim]
def visitReduce(self, node:AST.Reduce, args):
self.visit(node.expr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr]
def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args):
self.visit(node.expr, args)
self.visit(node.multExpr, args)
self.visit(node.addExpr, args)
self.node_to_secret[node] = self.node_to_secret[node.expr] | self.node_to_secret[node.multExpr] | self.node_to_secret[node.addExpr]
# A very basic alias analysis pass which creates alias sets for variables created
# through identity ops
# let a = b
class AliasAnalysis(ASTVisitor):
def __init__(self):
self.alias_sets = Util.DisjointSet()
super().__init__()
def add_alias(self, inp1, inp2):
self.alias_sets.make_set(inp1)
self.alias_sets.make_set(inp2)
self.alias_sets.union(inp1, inp2)
def get_alias_set(self, inp):
return self.alias_sets.get_key_set(inp)
def visitLet(self, node:AST.Let, args):
self.visit(node.decl)
self.visit(node.expr)
# Two IDs with same name can have diff pointers. Hence we store ID names instead of pointers.
if isinstance(node.decl, AST.ID):
self.add_alias(node.name.name, node.decl.name)
'''
We visit the program bottom up. Every time we encounter a use of a variable, we insert
a free instruction after it, unless the variable has already been freed.
We are basically freeing variables after their last use.
However, we also need to check for aliases of variables to avoid double frees and
use after free.
J100 = J99
J101 = J99 + 3 <- last use of J99
J102 = J100 * 2 <- last use of J100
if we transform this to:
J100 = J99
J101 = J99 + 3
free(J99)
J102 = J100 * 2 <- use after free
free(J100) <- double free
instead we want to do:
J100 = J99
J101 = J99 + 3
J102 = J100 * 2
free(J100)
..
'''
class GarbageCollector(ASTVisitor):
def __init__(self, ast):
self.ast = ast
self.secret_analysis = SecretFlowAnalysis()
self.secret_analysis.visit(self.ast)
self.alias_analysis = AliasAnalysis()
self.alias_analysis.visit(self.ast)
self.freed_nodes = set()
self.counter = 0
super().__init__()
def run(self, args):
self.visit(self.ast, args)
def isVarFreed(self, inp):
alias_set = self.alias_analysis.get_alias_set(inp)
if alias_set is None:
return inp in self.freed_nodes
for i in alias_set:
if i in self.freed_nodes:
return True
return False
def visitLet(self, node:AST.Let, args):
assert(isinstance(args, list))
assert(isinstance(args[0], MtdAST))
self.visit(node.expr, args)
usedVars = self.visit(node.decl, args)
if usedVars is None:
assert False, " visit of {} not implemented in GarbageCollector pass".format(str(type(node.decl)))
varsToDeAllocate = [i for i in usedVars if not self.isVarFreed(i)]
self.freed_nodes = self.freed_nodes.union(set(varsToDeAllocate))
astSubTree = node.expr
mtdForNewASTNodes = {AST.ASTNode.mtdKeyTFOpName : "No-op: ClearMem",
AST.ASTNode.mtdKeyTFNodeName : ""}
for ii, curVarName in enumerate(varsToDeAllocate):
newSubTree = AST.Let(AST.ID("cv"+str(self.counter+ii)),
AST.Func(AST.Operators.ClearMemSecret if self.secret_analysis.isSecret(curVarName) else AST.Operators.ClearMemPublic,
AST.ID(curVarName)),
AST.ID(""))
self.counter += 1
args[0].visit(newSubTree, mtdForNewASTNodes)
newSubTree.expr = astSubTree
node.expr = newSubTree
astSubTree = node.expr
def visitInt(self, node:AST.Int, args):
return set()
def visitFloat(self, node:AST.Float, args):
return set()
def visitInput(self, node:AST.Input, args):
return set()
def visitId(self, node:AST.ID, args):
return set([node.name])
def visitDecl(self, node:AST.Decl, args):
return set()
def visitTranspose(self, node:AST.Transpose, args):
usedVars = self.visit(node.expr, args)
return usedVars
def visitSlice(self, node:AST.Slice, args):
usedVars = self.visit(node.expr, args)
return usedVars
def visitReshape(self, node:AST.Reshape, args):
usedVars = self.visit(node.expr, args)
return usedVars
def visitPool(self, node:AST.Pool, args):
usedVars = self.visit(node.expr, args)
return usedVars
def visitUOp(self, node:AST.UOp, args):
usedVars = self.visit(node.expr, args)
return usedVars
def visitBOp(self, node:AST.BOp, args):
usedVars = self.visit(node.expr1, args) | self.visit(node.expr2, args)
return usedVars
def visitFunc(self, node:AST.Func, args):
usedVars = self.visit(node.expr, args)
return usedVars
def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args):
usedVars = set([])
for elem in node.argsList:
usedVars |= self.visit(elem, args)
return usedVars
def visitArgMax(self, node:AST.ArgMax, args):
usedVars = self.visit(node.expr, args) | self.visit(node.dim, args)
return usedVars
def visitReduce(self, node:AST.Reduce, args):
usedVars = self.visit(node.expr, args)
return usedVars
\ No newline at end of file
......@@ -141,4 +141,83 @@ def get_volume(shape: list):
vol = 1
for i in shape:
vol = vol * i
return vol
\ No newline at end of file
return vol
class DisjointSet:
class Node:
def __init__(self):
self.parent = self
self.children = []
def get_root(self):
if (self.parent != self):
old_parent = self.parent
self.parent = self.parent.get_root()
if self.parent != old_parent:
self.parent.children.append(self)
old_parent.children.remove(self)
return self.parent
else:
return self
def get_all_children(self):
all_children = []
all_children.extend(self.children)
tmp = []
for i in all_children:
tmp.extend(i.get_all_children())
all_children.extend(tmp)
return all_children
def __init__(self):
self.key_to_node = {}
self.node_to_key = {}
def inSet(self, inp):
return inp in self.key_to_node
def make_set(self, inp):
if self.inSet(inp):
return
n = self.Node()
self.key_to_node[inp] = n
self.node_to_key[n] = inp
def union(self, inp1, inp2):
n1 = self.key_to_node[inp1]
n2 = self.key_to_node[inp2]
r1 = n1.get_root()
r2 = n2.get_root()
if (r1 != r2):
r1.parent = r2
r2.children.append(r1)
def find(self, inp):
if not self.inSet(inp):
return None
return self.key_to_node[inp].get_root()
def find_key(self, inp):
node = self.find(inp)
if node is None:
return None
return self.node_to_key[node]
def get_set(self, inp):
if not self.inSet(inp):
return None
n = self.key_to_node[inp].get_root()
return [n] + n.get_all_children()
def get_key_set(self, inp):
nodes = self.get_set(inp)
if nodes is None:
return None
return [self.node_to_key[i] for i in nodes]
def print(self):
print(self.key_to_node)
print(self.node_to_key)
def print_set(self, inp):
print(self.get_key_set(inp))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment