Skip to content
Snippets Groups Projects
Commit 39e78075 authored by Bhatu's avatar Bhatu
Browse files

Remove double quotes from attributes

Remove the "" from attributes while parsing the graph def itself.
eg: "\"dtype\"" -> "dtype"
So we can directly refer to the attributes without adding double quotes to them.
parent 2172e484
No related branches found
No related tags found
No related merge requests found
......@@ -520,7 +520,7 @@ class Node:
return self.__attr
def getAttrVal(self, attrName):
qName = '"' + attrName + '"'
qName = attrName
if not qName in self.__attr:
return None
return self.__attr[qName]
......@@ -541,7 +541,7 @@ class Node:
#keyStr is already non-None .. there is then probably some error
print("Too many keys found while parsing attr for node at line =", cnt, file=sys.stderr)
return (False, cnt)
keyStr = tokens[1]
keyStr = tokens[1][1:-1]
elif (curToken == "value"):
curVal = Value()
(noParseError, cnt) = curVal.readFromFilePointer(fileP, cnt)
......@@ -570,13 +570,13 @@ class Node:
return (True, cnt)
elif (curToken == "name:"):
if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt)
self.__name = tokens[1]
self.__name = tokens[1][1:-1]
elif (curToken == "op:"):
if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt)
self.__op = tokens[1]
self.__op = tokens[1][1:-1]
elif (curToken == "input:"):
if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt)
self.__inputs.append(tokens[1])
self.__inputs.append(tokens[1][1:-1])
elif (curToken == "attr"):
(noParseError, cnt) = self.readAttrFromFilePointer(fileP, cnt)
if (not(noParseError)):
......
......@@ -36,7 +36,7 @@ def checkTFNodeNameForEq(curNodeOp:str, givenOp:str):
def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict):
curNodeOp = curNode.getOp()
ast = None
func = getattr(TFNodesAST, curNodeOp[1:-1]) #To remove the " at the begin and end
func = getattr(TFNodesAST, curNodeOp)
(assignedVarAST, curAST) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)
return (assignedVarAST, curAST)
......@@ -53,8 +53,8 @@ def generateIRCode(graph, extraInfoDict):
assert(curInp in dictNodeNameToOutVarStr) #Consequence of topological sorting of the TF graph
(assignedVarAST, curAst) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict)
mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp()[1:-1],
AST.ASTNode.mtdKeyTFNodeName : curNode.getName()[1:-1]}
mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(),
AST.ASTNode.mtdKeyTFNodeName : curNode.getName()}
if (curAst is None):
dictNodeNameToOutVarStr[curNode.getName()] = None
......@@ -104,7 +104,7 @@ def prefixAllPlaceHolderNodes(graph):
placeHolderNodes = []
remNodes = []
for curNode in allNodes:
if (curNode.getOp() == "\"Placeholder\"" or curNode.getOp() == "\"VariableV2\""):
if (curNode.getOp() == "Placeholder" or curNode.getOp() == "VariableV2"):
# Assert this is indeed a leaf node
assert(len(curNode.getInputsRef()) == 0)
placeHolderNodes.append(curNode)
......@@ -138,19 +138,19 @@ def main():
for curNode in graph.getAllNodesRef():
inputsRef = curNode.getInputsRef()
for i,curInput in enumerate(inputsRef):
if (curInput.startswith('"^')):
if (curInput.startswith('^')):
# My hypothesis from empirical observation is that inputs which have '^' ahead of the node name
# denote control flow dependency and not data dependency.
# For all purposes for this compilation, control and data dependency is considered same.
# The reasoning being that everything is serial -- and graph execution is done in a
# a topological sort.
inputsRef[i] = '"' + curInput.split('^')[-1]
inputsRef[i] = curInput.split('^')[-1]
# Create extra info dict
# Format : (sizeInfo)
extraInfoDict = {}
for k,v in sizeInfo.items():
extraInfoDict["\"" + k + "\""] = (v,)
extraInfoDict[k] = (v,)
for curNode in graph.getAllNodesRef():
if (curNode.getName() not in extraInfoDict):
extraInfoDict[curNode.getName()] = (None,)
......
......@@ -68,17 +68,17 @@ class TFNodesAST:
attrMapRef = curNode.getAttrMapRef()
transposeABool = transposeBBool = False
if ("\"transpose_a\"" in attrMapRef):
transposeABool = attrMapRef["\"transpose_a\""].getB()
if ("\"transpose_b\"" in attrMapRef):
transposeBBool = attrMapRef["\"transpose_b\""].getB()
if ("transpose_a" in attrMapRef):
transposeABool = attrMapRef["transpose_a"].getB()
if ("transpose_b" in attrMapRef):
transposeBBool = attrMapRef["transpose_b"].getB()
if (transposeABool): inp1AST = AST.Transp(inp1AST)
if (transposeBBool): inp2AST = AST.Transp(inp2AST)
return (None, AST.BOp(inp1AST, TFNodesAST.getOperatorsIdx('*'), inp2AST))
def Placeholder(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
curNodeInputType = curNode.getAttrMapRef()["\"dtype\""].getDataType()
curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType()
assert(curNodeInputType is not Graph.DataTypeEnum.DT_INVALID)
# NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code
......@@ -104,7 +104,7 @@ class TFNodesAST:
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==1)
curNodeDataType = curNode.getAttrMapRef()["\"T\""].getDataType()
curNodeDataType = curNode.getAttrMapRef()["T"].getDataType()
assert(curNodeDataType is not Graph.DataTypeEnum.DT_INVALID)
curNodeShape = extraNodeInfoDict[curNode.getName()][0]
......@@ -175,8 +175,8 @@ class TFNodesAST:
return (None, AST.Func(TFNodesAST.getOperatorsIdx('floor'), realDivAST))
def VariableV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
curNodeShapeLi = curNode.getAttrMapRef()["\"shape\""].getShape().getDimRef()[:]
curNodeInputType = curNode.getAttrMapRef()["\"dtype\""].getDataType()
curNodeShapeLi = curNode.getAttrMapRef()["shape"].getShape().getDimRef()[:]
curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType()
# NOTE : since this becomes an input node right now, i have also added to be prefixed at top in ProcessTFGraph::prefixAllPlaceHolderNodes()
# NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code
......@@ -188,8 +188,8 @@ class TFNodesAST:
def Const(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
assert(len(curNode.getInputsRef()) == 0)
tensor = curNode.getAttrMapRef()["\"value\""].getTensor()
curNodeDataType = curNode.getAttrMapRef()["\"dtype\""].getDataType()
tensor = curNode.getAttrMapRef()["value"].getTensor()
curNodeDataType = curNode.getAttrMapRef()["dtype"].getDataType()
curNodeShape = tensor.getShapeRef()[:] #create a different copy to not change the original copy
tensorConstantVal = tensor.getConstantVal()
......@@ -235,8 +235,8 @@ class TFNodesAST:
def Cast(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef) == 1)
sourceType = curNode.getAttrMapRef()["\"SrcT\""].getDataType()
destType = curNode.getAttrMapRef()["\"DstT\""].getDataType()
sourceType = curNode.getAttrMapRef()["SrcT"].getDataType()
destType = curNode.getAttrMapRef()["DstT"].getDataType()
return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
TFNodesAST.UninterpFuncCallNames.Cast.name,
[AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
......@@ -247,7 +247,7 @@ class TFNodesAST:
def ZerosLike(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==1)
curNodeOutputType = curNode.getAttrMapRef()["\"T\""].getDataType()
curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType()
assert(curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID)
retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
......@@ -261,7 +261,7 @@ class TFNodesAST:
curNodeOutputShape = extraNodeInfoDict[inputsRef[0]][0]
assert(len(curNodeOutputShape) == 1) #inputsRef[0] denotes a shape and should have a rank of 1
curNodeOutputType = curNode.getAttrMapRef()["\"T\""].getDataType()
curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType()
assert(curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID)
retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
......@@ -280,7 +280,7 @@ class TFNodesAST:
assert(FD)
assert(strideD)
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1
if (paddingUsedStr == "\"SAME\""):
if (paddingUsedStr == "SAME"):
# Reference for following:
# https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn
totalPaddingH = totalPaddingW = totalPaddingD = 0
......@@ -310,7 +310,7 @@ class TFNodesAST:
zPadDLeft = totalPaddingD // 2
zPadDRight = totalPaddingD - zPadDLeft
elif (paddingUsedStr == "\"VALID\""):
elif (paddingUsedStr == "VALID"):
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0
else:
zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1
......@@ -324,7 +324,7 @@ class TFNodesAST:
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==2)
stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
assert(stridesUsed[0]==1 and stridesUsed[3]==1)
strideH = stridesUsed[1]
strideW = stridesUsed[2]
......@@ -337,7 +337,7 @@ class TFNodesAST:
FH = filterShape[0]
FW = filterShape[1]
paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
[zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW,
FH, FW,
......@@ -363,7 +363,7 @@ class TFNodesAST:
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==2)
stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
assert(stridesUsed[0]==1 and stridesUsed[4]==1)
strideD = stridesUsed[1]
strideH = stridesUsed[2]
......@@ -379,7 +379,7 @@ class TFNodesAST:
FH = filterShape[1]
FW = filterShape[2]
paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
[zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD )
......@@ -406,7 +406,7 @@ class TFNodesAST:
inputsRef = curNode.getInputsRef()
assert(len(inputsRef)==3) #output_shape, filter, input
stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
assert(stridesUsed[0]==1 and stridesUsed[4]==1)
strideD = stridesUsed[1]
strideH = stridesUsed[2]
......@@ -427,7 +427,7 @@ class TFNodesAST:
outputH = outputShape[2]
outputW = outputShape[3]
paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
# Important: Using outputH and outputW in the below is not an error!
# For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse.
......@@ -463,12 +463,12 @@ class TFNodesAST:
options = {}
stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
assert((stridesUsed[0] == 1) and (stridesUsed[3] == 1))
strideH = stridesUsed[1]
strideW = stridesUsed[2]
kSizeUsed = curNode.getAttrMapRef()["\"ksize\""].getList().getILi()
kSizeUsed = curNode.getAttrMapRef()["ksize"].getList().getILi()
assert((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1))
kSizeH = kSizeUsed[1]
kSizeW = kSizeUsed[2]
......@@ -477,7 +477,7 @@ class TFNodesAST:
imgH = inputShape[1]
imgW = inputShape[2]
paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
[zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW,
kSizeH, kSizeW,
strideH, strideW,
......@@ -512,7 +512,7 @@ class TFNodesAST:
def ConcatV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
N = curNode.getAttrMapRef()["\"N\""].getI()
N = curNode.getAttrMapRef()["N"].getI()
assert(len(inputsRef) == N+1) #One extra for axis
#TODO : Since the axis of concat is constant, therefore, its known here - the input's sizes along that dim should be
# passed as input to the below function.
......@@ -535,7 +535,7 @@ class TFNodesAST:
def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef) == 3)
curNodeDataType = curNode.getAttrMapRef()["\"T\""].getDataType()
curNodeDataType = curNode.getAttrMapRef()["T"].getDataType()
retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
TFNodesAST.UninterpFuncCallNames.CreateCopy.name,
[AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), # of this
......@@ -556,8 +556,8 @@ class TFNodesAST:
attrMapRef = curNode.getAttrMapRef()
assert(len(inputsRef) == 2)
keepdims = False
if ("\"keep_dims\"" in attrMapRef):
keepdims = attrMapRef["\"keep_dims\""].getB()
if ("keep_dims" in attrMapRef):
keepdims = attrMapRef["keep_dims"].getB()
curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
return (None, AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
......@@ -570,8 +570,8 @@ class TFNodesAST:
attrMapRef = curNode.getAttrMapRef()
assert(len(inputsRef) == 2)
keepdims = False
if ("\"keep_dims\"" in attrMapRef):
keepdims = attrMapRef["\"keep_dims\""].getB()
if ("keep_dims" in attrMapRef):
keepdims = attrMapRef["keep_dims"].getB()
curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
return (None, AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
......@@ -601,12 +601,12 @@ class TFNodesAST:
def Pad(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
# Mode refers to 'CONSTANT', 'REFLECT' or 'SYMMETRIC'
mode = 0
if ("\"mode\"" in curNode.getAttrMapRef()):
mode = curNode.getAttrMapRef()["\"mode\""].getI()
if ("mode" in curNode.getAttrMapRef()):
mode = curNode.getAttrMapRef()["mode"].getI()
constant_values = 0
if ("\"constant_values\"" in curNode.getAttrMapRef()):
constant_values = curNode.getAttrMapRef()["\"constant_values\""].getI()
if ("constant_values" in curNode.getAttrMapRef()):
constant_values = curNode.getAttrMapRef()["constant_values"].getI()
assert(mode == 0 and constant_values == 0) # For now to make life easy - deal with SYMMETRIC AND REFLECT when time comes
inputsRef = curNode.getInputsRef()
......@@ -650,7 +650,7 @@ class TFNodesAST:
inputTensorShape = extraNodeInfoDict[inputsRef[0]][0]
inputTensorRank = len(inputTensorShape)
squeezeDims = curNode.getAttrMapRef()["\"squeeze_dims\""].getList().getILi()
squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi()
squeezeDimsRank = len(squeezeDims)
return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
......@@ -682,4 +682,4 @@ class TFNodesAST:
return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))
def VarHandleOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
return TFNodesAST.VariableV2(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)
return TFNodesAST.VariableV2(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment