Skip to content
Snippets Groups Projects
Commit db51c95f authored by Bhatu's avatar Bhatu
Browse files

Convert Slice operator to compile-time codgen.

parent 1a3b304c
No related branches found
No related tags found
No related merge requests found
......@@ -555,14 +555,21 @@ class TFNodesAST:
def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
assert(len(inputsRef) == 3)
curNodeDataType = curNode.getAttrMapRef()["T"].getDataType()
retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0],
TFNodesAST.UninterpFuncCallNames.CreateCopy.name,
[AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), # of this
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), # begin idx
AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]) # size
])
return (None, { curNode.getName() : retAST})
beginNode = graph.__getitem__(inputsRef[1])
sizeNode = graph.__getitem__(inputsRef[2])
assert beginNode.getAttrVal("value") is not None, "begin {} of Slice node {} has to be a constant".format(inputsRef[1], curNode.getName())
assert sizeNode.getAttrVal("value") is not None, "size {} of Slice node {} has to be a constant".format(inputsRef[2], curNode.getName())
begin = beginNode.getAttrVal("value").getTensor().getContentAsValArr()
size = sizeNode.getAttrVal("value").getTensor().getContentAsValArr()
assert begin is not None
assert size is not None
assert len(begin) == len(size)
subscriptRanges = []
for i in range(0,len(size)):
subscriptRanges.append((begin[i], begin[i] + size[i] - 1))
return (None, { curNode.getName() : AST.Slice(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
subscriptRanges)})
def Tile(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
inputsRef = curNode.getInputsRef()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment