import os
import sys

# Add the SeeDot directory to the import path so the project-local imports
# below (Graph, AST, TFNodesAST) resolve.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'SeeDot'))

import Graph
import AST.AST as AST
import _pickle as pickle
from TFNodesAST import TFNodesAST
from AST.PrintAST import PrintAST
from AST.MtdAST import MtdAST


def checkTFNodeNameForEq(curNodeOp: str, givenOp: str):
    """Return True if ``curNodeOp`` equals ``givenOp`` wrapped in double quotes.

    Op strings taken from graph nodes carry a surrounding '"' at the beginning
    and end, so ``givenOp`` is compared in its quoted form.
    """
    return curNodeOp == "\"" + givenOp + "\""


def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict):
    """Dispatch ``curNode`` to the TFNodesAST handler named after its op.

    The handler is the TFNodesAST attribute whose name is the node's op with
    the surrounding quotes removed. It is called as
    ``handler(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)``.

    Returns:
        The ``(assignedVarAST, curAST)`` pair produced by the handler.

    Raises:
        AttributeError: if TFNodesAST has no handler for the node's op.
    """
    curNodeOp = curNode.getOp()
    opName = curNodeOp[1:-1]  # remove the '"' at the beginning and end
    func = getattr(TFNodesAST, opName, None)
    if func is None:
        # Same exception type a bare getattr would raise, but naming the
        # offending op and node so unsupported TF ops are easy to diagnose.
        raise AttributeError("TFNodesAST has no handler for TF op "
                             + curNodeOp + " (node " + curNode.getName() + ")")
    (assignedVarAST, curAST) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict)
    return (assignedVarAST, curAST)


def generateIRCode(graph, extraInfoDict):
    """Translate the TF graph into one SeeDot program: a chain of nested AST.Let nodes.

    Nodes are visited in ``graph.getAllNodesRef()`` order, which must already
    be a topological order (every input must have been visited earlier).
    Every node whose handler yields an AST gets a fresh output variable name
    ``"J<n>"`` and a Let node. The previous innermost Let's ``expr`` is
    re-pointed at the new Let, so the chain nests one level per emitted node.

    Args:
        graph: The parsed TF graph.
        extraInfoDict: Per-node extra info, keyed by quoted node name; passed
            through unchanged to the TFNodesAST handlers.

    Returns:
        ``(program, dictNodeNameToOutVarStr)``. ``program`` is the outermost
        Let, or None if no node produced an AST. The dict maps each node name
        to its output variable name (``"J0"``, ``"J1"``, ...), or to None for
        nodes whose handler produced no AST.
    """
    program = None
    innerMostLetASTNode = None
    dictNodeNameToOutVarStr = {}
    outVarCt = 0
    outVarPrefix = "J"
    mtdAST = MtdAST()
    for curNode in graph.getAllNodesRef():
        for curInp in curNode.getInputsRef():
            assert curInp in dictNodeNameToOutVarStr  # Consequence of topological sorting of the TF graph
        (assignedVarAST, curAst) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict)

        # Metadata attached to the generated Let: op and node name, quotes stripped.
        mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName: curNode.getOp()[1:-1],
                        AST.ASTNode.mtdKeyTFNodeName: curNode.getName()[1:-1]}

        if curAst is None:
            dictNodeNameToOutVarStr[curNode.getName()] = None
            continue

        curOutVarStr = outVarPrefix + str(outVarCt)
        # A handler may supply its own variable AST; otherwise use "J<n>".
        curOutVarAstNode = (assignedVarAST if assignedVarAST else AST.ID(curOutVarStr))
        if program is not None:
            assert type(innerMostLetASTNode) is AST.Let
            newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode)
            mtdAST.visit(newNode, mtdForCurAST)
            # Splice the new Let in as the body of the current innermost Let.
            innerMostLetASTNode.expr = newNode
            innerMostLetASTNode = newNode
        else:
            # NOTE(review): unlike the branch above, the first Let always binds
            # AST.ID(curOutVarStr) even when assignedVarAST is set -- confirm
            # this asymmetry is intended.
            innerMostLetASTNode = AST.Let(AST.ID(curOutVarStr), curAst, curOutVarAstNode)
            mtdAST.visit(innerMostLetASTNode, mtdForCurAST)
            innerMostLetASTNode.depth = 0
            program = innerMostLetASTNode
        dictNodeNameToOutVarStr[curNode.getName()] = curOutVarStr
        outVarCt += 1
    return (program, dictNodeNameToOutVarStr)


def readSizeInfo(fileName):
    """Parse a size-info file into ``{nodeName: [dim, ...]}``.

    Each non-blank line is ``<nodeName> [dim0 dim1 ...]`` with whitespace-
    separated integer dims. A node listed with no dims is recorded as ``[1]``.
    Blank or whitespace-only lines are skipped. If a node name repeats, the
    last line wins.

    Raises:
        ValueError: if a dimension token is not an integer.
    """
    sizeInfo = {}
    with open(fileName) as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # Previously tokens[0] raised IndexError on e.g. a trailing
                # empty line; such lines carry no info, so skip them.
                continue
            nodeName, dimStrs = tokens[0], tokens[1:]
            # Nodes with no size info are not getting outputted right now,
            # but a bare name is still accepted and recorded as [1].
            sizeInfo[nodeName] = [int(d) for d in dimStrs] if dimStrs else [1]
    return sizeInfo


def prefixAllPlaceHolderNodes(graph):
    """Move all Placeholder / VariableV2 nodes to the front of the graph's node list.

    Later in the pipeline, the placeholder nodes that come up as cin statements
    are excluded from the timing calculation, so they are output together
    first. This doesn't violate the topological ordering because such nodes
    are leaves in the graph (asserted below: they have no inputs). Relative
    order within each group is preserved.
    """
    placeHolderNodes = []
    remNodes = []
    for curNode in graph.getAllNodesRef():
        curOp = curNode.getOp()
        if checkTFNodeNameForEq(curOp, "Placeholder") or checkTFNodeNameForEq(curOp, "VariableV2"):
            # Assert this is indeed a leaf node
            assert len(curNode.getInputsRef()) == 0
            placeHolderNodes.append(curNode)
        else:
            remNodes.append(curNode)
    graph.setNodesList(placeHolderNodes + remNodes)


def main():
    """Read graphDef.mtdata and sizeInfo.mtdata next to argv[1]; pickle the SeeDot AST.

    The output is written to astOutput.pkl in the same directory.
    """
    # The program is a Let chain nested one level per emitted node; presumably
    # the higher limit is for recursive traversal/pickling of that deep
    # structure -- TODO confirm.
    sys.setrecursionlimit(10000)

    # First read the graph file
    if len(sys.argv) < 2:
        print("TF python file unspecified.", file=sys.stderr)
        sys.exit(1)  # sys.exit, not the site-module exit() helper

    filename = sys.argv[1]
    folderName = os.path.dirname(filename)
    graphFileName = os.path.join(folderName, 'graphDef.mtdata')
    graph = Graph.Graph()
    with open(graphFileName) as graphFile:
        graph.readFromFilePointer(graphFile)

    # Read the sizeInfo also
    sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata')
    sizeInfo = readSizeInfo(sizeInfoFileName)

    # Place all PlaceHolder nodes together at the beginning
    prefixAllPlaceHolderNodes(graph)

    # Re-format the input names of nodes (in place)
    for curNode in graph.getAllNodesRef():
        inputsRef = curNode.getInputsRef()
        for i, curInput in enumerate(inputsRef):
            # TODO for training: inputs ending in ':1"' are not handled
            # correctly; a candidate was inputsRef[i] = curInput.split(':1')[0] + '"'.
            if curInput.startswith('"^'):
                # My hypothesis from empirical observation is that inputs which
                # have '^' ahead of the node name denote control flow dependency
                # and not data dependency. For all purposes of this compilation,
                # control and data dependency are considered the same, the
                # reasoning being that everything is serial -- and graph
                # execution is done in a topological sort.
                inputsRef[i] = '"' + curInput.split('^')[-1]

    # Create extra info dict, keyed by quoted node name.
    # Format : (sizeInfo,) -- (None,) for nodes without size info.
    extraInfoDict = {"\"" + k + "\"": (v,) for k, v in sizeInfo.items()}
    for curNode in graph.getAllNodesRef():
        extraInfoDict.setdefault(curNode.getName(), (None,))

    print("Generating code from TF graph def : ", graphFileName, " ...")
    program, _ = generateIRCode(graph, extraInfoDict)
    print("SeeDot AST generation done. Pickling the AST.")
    with open(os.path.join(folderName, 'astOutput.pkl'), 'wb') as f:
        pickle.dump(program, f)


if __name__ == "__main__":
    main()