from AST import ASTNode, getParseTreeNodes, getASTNode
from LineNodes import BlockNode, VarDclNode
from ExprPrimaryNodes import makeNodeFromExpr
from UnitNodes import ParamNode
from TheTypeNode import TypeNode
from Environment import Env
from collections import OrderedDict
from CodeGenUtils import p, pLabel, genProcedure, importHelper


# field
class FieldNode(ASTNode):
    """AST node for a field declaration of a class.

    Wraps the field's VarDclNode and records its modifiers and declaration
    order (the order is used by the forward-reference check in checkType).
    """

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        """Build the field node from its parse tree.

        parseTree: parse tree node whose children are inspected by name
                   ('methodMod' children carry modifiers, 'variableDcl'
                   carries the declaration).
        typeName:  name of the enclosing type; forwarded to VarDclNode.
        order:     position of this field, compared against other fields'
                   order when checking forward references.
        """
        self.parseTree = parseTree
        self.name = ''
        self.variableDcl = None
        self.mods = []
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order
        self.myType = None

        for node in parseTree.children:
            if node.name == 'methodMod':
                for m in node.children:
                    self.mods.append(m.lex)

            elif node.name == 'variableDcl':
                self.variableDcl = VarDclNode(node, self.typeName, isField=self)
                # The field takes its name and type from its declaration.
                self.name = self.variableDcl.name
                self.myType = self.variableDcl.myType
                self.children.append(self.variableDcl)

    def __eq__(self, other):
        """Two fields are considered equal when they have the same name."""
        return self.name == other.name

    def checkType(self):
        """Type-check the declaration and reject illegal forward references.

        Raises:
            Exception: if the initializer refers to the field being declared,
                or to a field of the same type that is declared at or after
                this one (unless the name contains "this").
        """
        self.variableDcl.checkType()

        # check forward reference
        if not self.variableDcl.variableInit:
            return

        for n in getForwardRefNames(self.variableDcl.variableInit):
            link = n.prefixLink
            if link is self.variableDcl:
                raise Exception("ERROR: Forward reference of field {} in itself is not allowed.".format(link.name))

            if link.__class__.__name__ == 'FieldNode' \
                    and link.typeName == self.typeName \
                    and self.order <= link.order \
                    and "this" not in n.name:
                raise Exception("ERROR: Forward reference of field {} is not allowed.".format(link.name))

    # Note: 1. Not calling codeGen for variableDcl since all variableDcl code assumes that the variable is a LOCAL variable
    #       2. Only calling codeGen on the variableInit of self.variableDcl if the field is static, since non-static fields
    #          should be initialized by the constructor
    def codeGen(self):
        """Generate the assembly for this field into self.data and self.code.

        Runs at most once per node (guarded by the presence of self.code).
        Non-static fields produce empty strings for both attributes.
        """
        if hasattr(self, "code"):
            return

        self.code = ""
        self.data = ""

        # Only static fields are emitted here: the pointer lives in assembly.
        if "static" not in self.mods:
            return

        label = self.typeName + "_" + self.name

        self.data += ";Declaring a static field: " + label + "\n"
        self.data += pLabel(name=label, type="static") + \
            p(instruction="dd", arg1=0, comment="Declaring space on assembly for a static field")
        self.data += ";End of declaration of static field\n"

        # Initializing static fields
        # Static fields are initialized in the order of declaration within the class and have to be
        # initialized before the test() method is called.
        initNode = self.variableDcl.variableInit
        if initNode:
            self.code += ";Start of initialization of static field\n"
            initNode.codeGen()

            self.code += "; Calculating the initial value of declared field: " + label + "\n"
            self.code += initNode.code
            # Filling in label with pointer's address
            # NOTE(review): the "S_" prefix presumably matches what pLabel(type="static") emits - confirm.
            self.code += p(instruction="mov", arg1="ebx", arg2="dword S_" + label) + \
                p(instruction="mov", arg1="[ebx]", arg2="eax", comment="eax is a pointer to field value in heap")
            self.code += ";End of initialization of static field\n"


###########################################################

# method
class MethodNode(ASTNode):
    """AST node for a method, interface method or constructor declaration."""

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        """Build the method node from its parse tree.

        parseTree: parse tree node of the declaration.
        typeName:  name of the enclosing type; forwarded to child nodes.
        order:     position of this member within its type.
        """
        self.parseTree = parseTree
        self.name = ''
        self.methodType = ''
        self.params = []  # a list of paramNodes
        self.mods = []
        self.body = None
        self.paramTypes = ''  # a string of param types (signature) for easy type checking against arguments
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order
        self.myType = None
        self.isInterM = False  # if self is an interface method
        self.replace = None  # pointer to the method that this method overrides

        # get method name
        nameNodes = getParseTreeNodes(['ID'], parseTree, ['params', 'type', 'methodBody'])
        for n in nameNodes:
            self.name = n.lex

        # params
        nameNodes = getParseTreeNodes(['param'], parseTree, ['methodBody'])
        for n in nameNodes:
            paramNode = ParamNode(n, self.typeName)
            self.params.append(paramNode)
            paramType = paramNode.paramType.myType
            # Array parameters get an "Array" suffix in the signature string.
            if paramType.isArray:
                self.paramTypes += paramType.name + "Array"
            else:
                self.paramTypes += paramType.name

        # return type; stays '' when the declaration has neither a type nor VOID
        nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
        for n in nameNodes:
            if n.name == 'VOID':
                self.methodType = TypeNode('VOID', self.typeName)
            else:
                self.methodType = TypeNode(n, self.typeName)

        for node in parseTree.children:
            if node.name == 'methodMod' or node.name == "interfaceMod":
                for m in node.children:
                    self.mods.append(m.lex)

            elif node.name == 'methodBody':
                nameNodes = getParseTreeNodes(['block'], node)
                for n in nameNodes:
                    self.body = BlockNode(n, typeName)

        if self.body:
            self.children.append(self.body)
        self.children.append(self.methodType)
        self.children.extend(self.params)

    def __eq__(self, other):
        """Two methods are equal when their names and parameter types match pairwise."""
        if self.name != other.name or len(self.params) != len(other.params):
            return False
        return all(mine.paramType == theirs.paramType
                   for mine, theirs in zip(self.params, other.params))

    def buildEnv(self, parentEnv):
        """Create this method's environment (child of parentEnv) holding its parameters."""
        env = Env(parentEnv)
        for param in self.params:
            env.addtoEnv(param)
        self.env = env
        return env

    def disambigName(self):
        """Run name disambiguation on the body (if any) and on every parameter."""
        if self.body:
            self.body.disambigName()
        for param in self.params:
            param.disambigName()

    def checkType(self):
        """Type-check the method: parameters, body, static context and return statements.

        Raises:
            Exception: if a static method uses `this` or a non-static member,
                or if a return statement does not match the method's type.
        """
        if self.methodType:  # falsy for constructors (no return type)
            self.myType = self.methodType.myType

        for param in self.params:
            param.checkType()

        # No method body: nothing further to check as the function isn't implemented
        if not self.body:
            return
        self.body.checkType()

        # check no use of this in static method
        if 'static' in self.mods:
            for n in getNameNodes(self.body):
                # NOTE(review): `'this' in n.name` is a containment test on the name - confirm the
                # shape of NameNode.name so identifiers merely containing "this" are not rejected.
                usesNonStaticMember = (
                    n.pointToThis
                    and n.prefixLink.__class__.__name__ in ('MethodNode', 'FieldNode')
                    and 'static' not in n.prefixLink.mods
                )
                if 'this' in n.name or usesNonStaticMember:
                    raise Exception("ERROR: Cannot use non-static member {} in static method {} in class {}".format(n.name, self.name, self.typeName))

        # Checking return types against the function type
        isVoid = bool(self.myType) and self.myType.name == "void"
        for n in getASTNode(["ReturnNode"], self.body):
            n.method = self  # every return statement must know its enclosing method

            # Functions of type void are only valid if the return statement is a bare `return;`
            if isVoid:
                if n.myType:
                    raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))
                continue

            # Non-void cases (constructors have no myType and are not checked)
            if self.myType and not self.myType.assignable(n.myType):
                raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))

    def reachCheck(self, inMaybe=True):
        """Run reachability analysis on the body and verify non-void methods return.

        Raises:
            Exception: if a non-void method's body can complete normally.
        """
        self.outMaybe = False

        # Check reachability of method body
        if self.body:
            self.body.reachCheck(True)  # For method bodies, in[L] = maybe by default
            self.outMaybe = self.body.outMaybe

        # Omitting the check for constructors and void functions
        if not self.methodType or self.myType.name == "void":
            return

        # out[method body] must not be maybe for non-void methods
        if self.outMaybe:
            raise Exception("Non-void method '{}' in class '{}' does not return".format(self.name, self.typeName))

    def _assignParamOffsets(self, firstOffset):
        """Give each parameter a stack offset, the last declared parameter getting firstOffset."""
        for i, param in enumerate(reversed(self.params)):
            param.offset = i * 4 + firstOffset

    def _reserveLocals(self):
        """Assign offsets to the body's local variables; return (locals, code pushing one slot each)."""
        localVars = getVarDclNodes(self.body)
        pushCode = ""
        for i, localVar in enumerate(localVars):
            # NOTE(review): the base of 16 presumably accounts for what genProcedure saves - confirm.
            localVar.offset = i * 4 + 16
            pushCode += p("push", 0)
        return localVars, pushCode

    def codeGen(self):
        """Generate the assembly for a (non-constructor) method into self.code.

        Does nothing when code has already been generated.
        """
        if hasattr(self, "code") and self.code != "":
            return

        signature = self.typeName + "_" + self.name + "_" + self.paramTypes
        self.label = "M_" + signature
        self.code = pLabel(signature, "method")

        # params: 12 since the stack is of the order: ebp, eip, o, params;
        # static methods take no object argument, so their params start at 8
        self._assignParamOffsets(8 if 'static' in self.mods else 12)

        if not self.body:
            self.code += p("mov", "eax", "0")
            self.code += p("ret", "")
            return

        # push all local var to stack
        localVars, bodyCode = self._reserveLocals()

        self.body.codeGen()
        bodyCode += self.body.code

        bodyCode += self.label + "_end: ; end of method for " + self.name + "\n"

        # pop off all the local var
        for _ in localVars:
            bodyCode += p("pop", "edx")

        self.code += genProcedure(bodyCode, "method definition for " + self.name)

    # This method is called instead of codeGen if this is a constructor
    def codeGenConstructor(self):
        """Generate the assembly for a constructor into self.code.

        Emits, in order: local variable slots, a call to the zero-argument
        superclass constructor (if there is a superclass), the initializers
        of non-static fields in declaration order, and the constructor body.
        Does nothing when code has already been generated.
        """
        if hasattr(self, "code") and self.code != "":
            return

        # populate param offsets
        # 12 since the stack is of the order: ebp, eip, o, params
        self._assignParamOffsets(12)

        myClass = self.env.getNode(self.typeName, 'type')

        signature = self.typeName + "_" + self.name + "_" + self.paramTypes
        self.label = "M_" + signature
        self.code = pLabel(signature, "method")  # label
        thisLoc = 8  # Right after ebp, eip
        thisOperand = "[ebp + " + str(thisLoc) + "]"
        bodyCode = ""
        localVars = []

        if self.body:
            # push all local var to stack
            localVars, bodyCode = self._reserveLocals()

        # call parent constructor (zero argument)
        if myClass.superClass:
            superName = myClass.superClass.name
            suLabel = "M_" + superName + "_" + superName + "_"
            bodyCode += importHelper(superName, self.typeName, suLabel)
            bodyCode += p("mov", "eax", thisOperand)
            bodyCode += p("push", "eax", None, "# Pass THIS as argument to superClass.")
            bodyCode += p("call", suLabel)
            bodyCode += p("add", "esp", "4")  # pop object off stack

        # init fields, in declaration order
        for f in sorted(myClass.fields, key=lambda x: x.order):
            initNode = f.variableDcl.variableInit
            if 'static' not in f.mods and initNode:
                initNode.codeGen()
                bodyCode += initNode.code
                bodyCode += p("mov", "ebx", thisOperand)  # THIS
                bodyCode += p("mov", "[ebx + " + str(f.offset) + " ]", "eax")

        # body code
        if self.body:
            self.body.codeGen()
            bodyCode += self.body.code
            bodyCode += self.label + "_end: ; end of method for " + self.name + "\n"

            # pop off all the local var
            for _ in localVars:
                bodyCode += p("pop", "edx")
        else:
            bodyCode += p("mov", "eax", "0")

        self.code += genProcedure(bodyCode, "Constructor definition for " + self.name + " " + self.paramTypes)

    # gets the top-most level method that this method overrides
    def getTopReplace(self):
        """Follow the `replace` chain and return the method at its end (self if none)."""
        top = self
        while top.replace:
            top = top.replace
        return top


############# helper for forward ref checking ########
# Input: AST Node
# Output: A list of names to be checked
def getForwardRefNames(node):
    """Collect the NameNodes under node that are subject to the forward-reference check.

    For an AssignNode only the right-hand side is searched, so the assigned
    name itself is not reported. Empty children are skipped.
    """
    if not node:
        return []

    kind = node.__class__.__name__
    if kind == 'NameNode':
        return [node]

    if kind == 'AssignNode':
        return list(getForwardRefNames(node.right))

    result = []
    for c in node.children:
        result.extend(getForwardRefNames(c))
    return result


# Input: AST Node
# Output: A list of names to be checked
def getNameNodes(node):
    """Collect every NameNode under node (including node itself); empty nodes yield []."""
    if not node:
        return []

    if node.__class__.__name__ == 'NameNode':
        return [node]

    found = []
    for child in node.children:
        found.extend(getNameNodes(child))
    return found


# Input: Block Node
# Output: A list of local var dcl to be pushed onto the stack
def getVarDclNodes(node):
    """Collect the local variable declarations under node, in source order.

    Descends through BlockNode statements and the bodies of ForNode,
    WhileNode and IfNode; any other node kind contributes nothing.
    """
    kind = node.__class__.__name__

    if kind == 'VarDclNode':
        return [node]

    result = []
    if kind == 'BlockNode':
        for s in node.statements:
            result += getVarDclNodes(s)

    elif kind == 'ForNode':
        # NOTE(review): forInit is collected as-is, whatever node kind it is - confirm it is
        # always a declaration when present.
        if node.forInit:
            result.append(node.forInit)
        if node.bodyStatement:
            result += getVarDclNodes(node.bodyStatement)

    elif kind == 'WhileNode':
        if node.whileBody:
            result += getVarDclNodes(node.whileBody)

    elif kind == 'IfNode':
        result += getVarDclNodes(node.ifBody)
        if node.elseBody:
            result += getVarDclNodes(node.elseBody)

    return result