MemberNodes.py 12.57 KiB
from AST import ASTNode, getParseTreeNodes, getASTNode
from LineNodes import BlockNode, VarDclNode
from ExprPrimaryNodes import makeNodeFromExpr
from UnitNodes import ParamNode
from TheTypeNode import TypeNode
from Environment import Env
from collections import OrderedDict
from CodeGenUtils import p, pLabel, genProcedure
# field
class FieldNode(ASTNode):
    """AST node for a field declaration inside a type body."""

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        """Build the node from the field's parse tree.

        parseTree -- parse-tree node of the field declaration
        typeName  -- name of the enclosing type
        order     -- declaration position of this member within its type;
                     checkType compares it to reject forward references
        """
        self.parseTree = parseTree
        self.name = ''
        self.variableDcl = None
        self.mods = []
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order
        self.myType = None  # taken from the variable declaration below

        for node in parseTree.children:
            if node.name == 'methodMod':
                # field modifiers are parsed under a 'methodMod' node
                for m in node.children:
                    self.mods.append(m.lex)
            elif node.name == 'variableDcl':
                self.variableDcl = VarDclNode(node, self.typeName)
                self.name = self.variableDcl.name
                self.myType = self.variableDcl.myType
                self.children.append(self.variableDcl)

    def __eq__(self, other):
        """Fields are considered equal when their names match."""
        return self.name == other.name

    def checkType(self):
        """Type-check the declaration and reject illegal forward references.

        Raises Exception when the initializer refers to the field being
        declared, or to a field of the same type declared at the same or a
        later position without going through ``this``.
        """
        self.variableDcl.checkType()

        if not self.variableDcl.variableInit:
            return

        # check forward reference
        for n in getForwardRefNames(self.variableDcl.variableInit):
            if n.prefixLink is self.variableDcl:
                raise Exception("ERROR: Forward reference of field {} in itself is not allowed.".format(n.prefixLink.name))
            if (n.prefixLink.__class__.__name__ == 'FieldNode'
                    and n.prefixLink.typeName == self.typeName
                    and self.order <= n.prefixLink.order
                    and "this" not in n.name):
                raise Exception("ERROR: Forward reference of field {} is not allowed.".format(n.prefixLink.name))

    def codeGen(self):
        """Generate assembly for a static field into ``self.code`` (cached).

        Non-static fields produce no code here; they are initialized by the
        constructor. The variableDcl itself is never code-generated because
        VarDclNode code generation assumes a LOCAL variable; only the
        initializer is generated. Static fields are initialized in declaration
        order within the class and have to be initialized before the test()
        method is called.
        """
        if hasattr(self, "code"):
            return
        self.code = ""
        label = self.typeName + "_" + self.name

        if "static" not in self.mods:
            return

        # static fields: the pointer lives in assembly
        self.code += ";Declaring a static field: " + label + "\n"
        self.code += (pLabel(name=label, type="static")
                      + p(instruction="dd", arg1="64", comment="Declaring space on assembly for a static field"))

        initNode = self.variableDcl.variableInit
        if initNode:
            # NOTE(review): codeGenConstructor generates `variableInit.right` for
            # instance fields, whereas the whole variableInit is generated here —
            # confirm which of the two leaves the initial value in eax.
            initNode.codeGen()
            self.code += "; Calculating the initial value of declared field: " + label + "\n"
            self.code += initNode.code
            # Filling in label with pointer's address
            self.code += (p(instruction="mov", arg1="ebx", arg2="dword S_" + label)
                          + p(instruction="mov", arg1="[ebx]", arg2="eax", comment="eax is a pointer to field value in heap"))

        self.code += ";End of declaration of static field\n"
###########################################################
# method
class MethodNode(ASTNode):
    """AST node for a method or constructor declaration.

    Constructors have no return type, so their ``methodType`` stays ``''``
    (falsy); checkType and reachCheck rely on that to skip return checks.
    """

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        """Build the node from the method's parse tree.

        parseTree -- parse-tree node of the method declaration
        typeName  -- name of the enclosing type
        order     -- declaration position of this member within its type
        """
        self.parseTree = parseTree
        self.name = ''
        self.methodType = ''
        self.params = []  # a list of paramNodes
        self.mods = []
        self.body = None
        self.paramTypes = ''  # a string of param types (signature) for easy type checking against arguments
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order
        self.myType = None

        # get method name
        nameNodes = getParseTreeNodes(['ID'], parseTree, ['params', 'type', 'methodBody'])
        for n in nameNodes:
            self.name = n.lex

        # params
        nameNodes = getParseTreeNodes(['param'], parseTree, ['methodBody'])
        for n in nameNodes:
            paramNode = ParamNode(n, self.typeName)
            self.params.append(paramNode)
            self.paramTypes += paramNode.paramType.myType.name

        # return type (left as '' when there is neither a 'type' nor a 'VOID' node)
        nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
        for n in nameNodes:
            if n.name == 'VOID':
                self.methodType = TypeNode('VOID', self.typeName)
            else:
                self.methodType = TypeNode(n, self.typeName)

        # modifiers and body
        for node in parseTree.children:
            if node.name == 'methodMod' or node.name == "interfaceMod":
                for m in node.children:
                    self.mods.append(m.lex)
            elif node.name == 'methodBody':
                nameNodes = getParseTreeNodes(['block'], node)
                for n in nameNodes:
                    self.body = BlockNode(n, typeName)

        if self.body:
            self.children.append(self.body)
        self.children.append(self.methodType)
        self.children.extend(self.params)

    def __eq__(self, other):
        """Two methods are equal when name and parameter types all match."""
        if self.name != other.name or len(self.params) != len(other.params):
            return False
        return all(mine.paramType == theirs.paramType
                   for mine, theirs in zip(self.params, other.params))

    def buildEnv(self, parentEnv):
        """Create this method's environment holding its parameters."""
        env = Env(parentEnv)
        for param in self.params:
            env.addtoEnv(param)
        self.env = env
        return env

    def disambigName(self):
        """Disambiguate names in the body and the parameters."""
        if self.body:
            self.body.disambigName()
        for param in self.params:
            param.disambigName()

    def checkType(self):
        """Type-check parameters and body, then validate statics and returns.

        Every ReturnNode in the body gets ``method`` set to this node.
        Raises Exception when a static method uses ``this`` or a non-static
        member, or when a return statement does not match the method's type.
        """
        if self.methodType:  # constructors have an empty methodType
            self.myType = self.methodType.myType
        for param in self.params:
            param.checkType()

        # No method body: do not check type as function isn't implemented
        if not self.body:
            return
        self.body.checkType()

        # check no use of this (explicit or implicit) in static method
        if 'static' in self.mods:
            for n in getNameNodes(self.body):
                if 'this' in n.name or (n.pointToThis
                                        and n.prefixLink.__class__.__name__ in ('MethodNode', 'FieldNode')
                                        and 'static' not in n.prefixLink.mods):
                    raise Exception("ERROR: Cannot use non-static member {} in static method {} in class {}".format(n.name, self.name, self.typeName))

        # Checking return statements against the method type
        for n in getASTNode(["ReturnNode"], self.body):
            n.method = self
            if self.myType and self.myType.name == "void":
                # a void method may only use a bare `return;`
                if n.myType:
                    raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))
                # keep going so every later ReturnNode is also linked and checked
                continue
            # non-void case
            if self.myType and not self.myType.assignable(n.myType):
                raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))

    def reachCheck(self, inMaybe=True):
        """Run reachability analysis on the body.

        Sets ``self.outMaybe`` and raises Exception when a non-void,
        non-constructor method can complete without returning.
        """
        self.outMaybe = False
        if self.body:
            self.body.reachCheck(True)  # For method bodies, in[L] = maybe by default
            self.outMaybe = self.body.outMaybe

        # Omitting the check for constructors and void functions
        if not self.methodType or self.myType.name == "void":
            return
        if self.outMaybe:
            raise Exception("Non-void method '{}' in class '{}' does not return".format(self.name, self.typeName))

    def codeGen(self):
        """Generate assembly for a non-constructor method into ``self.code``.

        Each local variable of the body gets a stack offset and a pushed
        zero slot that is popped again after the ``<label>_end`` label.
        Methods without a body emit ``mov eax, 0`` followed by ``ret``.
        """
        if hasattr(self, "code") and self.code != "":
            return

        self.label = "M_" + self.typeName + "_" + self.name + "_" + self.paramTypes
        self.code = pLabel(self.typeName + "_" + self.name + "_" + self.paramTypes, "method")

        if self.body:
            bodyCode = ""
            # push all local var to stack
            localVars = getVarDclNodes(self.body)
            for i, var in enumerate(localVars):
                var.offset = i * 4 + 16
                bodyCode += p("push", 0)

            self.body.codeGen()
            bodyCode += self.body.code
            bodyCode += self.label + "_end: ; end of method for " + self.name + "\n"

            # pop off all the local var
            for _ in localVars:
                bodyCode += p("pop", "edx")

            self.code += genProcedure(bodyCode, "method definition for " + self.name)
        else:
            self.code += p("mov", "eax", "0")
            self.code += p("ret", "")

    # This method is called instead of codeGen if this is a constructor
    def codeGenConstructor(self):
        """Generate assembly for this constructor into ``self.code``.

        The procedure calls the superclass constructor (if any), evaluates
        the initializers of the non-static fields in declaration order and
        stores them into THIS, then runs the constructor body.
        """
        if hasattr(self, "code") and self.code != "":
            return

        myClass = self.env.getNode(self.typeName, 'type')

        self.label = "M_" + self.typeName + "_" + self.name + "_" + self.paramTypes
        self.code = pLabel(self.typeName + "_" + self.name + "_" + self.paramTypes, "method")  # label

        # NOTE(review): after a standard `push ebp; mov ebp, esp` prologue the
        # arguments sit above ebp (ebp + 8 and up); `ebp -` addresses below the
        # frame pointer — confirm against genProcedure's frame layout.
        thisLoc = len(self.params) * 4 + 8
        thisRef = "[ebp - " + str(thisLoc) + "]"

        bodyCode = ""
        localVars = []
        if self.body:
            # push all local var to stack
            localVars = getVarDclNodes(self.body)
            for i, var in enumerate(localVars):
                var.offset = i * 4 + 16
                bodyCode += p("push", 0)

        # call parent constructor
        if myClass.superClass:
            # importHelper is not imported at the top of this file; presumably it
            # lives in CodeGenUtils alongside p/pLabel/genProcedure — confirm.
            from CodeGenUtils import importHelper
            # NOTE(review): self.label has the form M_<type>_<name>_<paramTypes>,
            # but this label omits the constructor name — confirm it matches the
            # superclass's zero-argument constructor label.
            suLabel = "M_" + myClass.superClass.name + "_"
            bodyCode += importHelper(myClass.superClass.name, self.typeName, suLabel)
            bodyCode += p("mov", "eax", thisRef)
            bodyCode += p("push", "eax", None, "# Pass THIS as argument to superClass.")
            bodyCode += p("call", suLabel)

        # init non-static fields in declaration order
        for f in sorted(myClass.fields, key=lambda x: x.order):
            if 'static' not in f.mods and f.variableDcl.variableInit:
                initExpr = f.variableDcl.variableInit.right
                initExpr.codeGen()
                bodyCode += initExpr.code
                bodyCode += p("mov", "ebx", thisRef)  # THIS
                bodyCode += p("mov", "[ebx + " + str(f.offset) + " ]", "eax")

        # body code
        if self.body:
            self.body.codeGen()
            bodyCode += self.body.code
            bodyCode += self.label + "_end: ; end of method for " + self.name + "\n"

            # pop off all the local var
            for _ in localVars:
                bodyCode += p("pop", "edx")
        else:
            bodyCode += p("mov", "eax", "0")

        self.code += genProcedure(bodyCode, "Constructor definition for " + self.name + " " + self.paramTypes)
############# helper for forward ref checking ########
# Input: AST Node
# Output: A list of NameNodes to be checked
def getForwardRefNames(node):
    """Collect the NameNodes under *node* that need a forward-reference check.

    A NameNode is returned as-is. For an AssignNode only the right-hand side
    is searched; for any other node every child is searched, in order.
    """
    kind = node.__class__.__name__
    if kind == 'NameNode':
        return [node]

    subtrees = [node.right] if kind == 'AssignNode' else node.children
    names = []
    for sub in subtrees:
        names += getForwardRefNames(sub)
    return names
# Input: AST Node
# Output: A list of NameNodes to be checked
def getNameNodes(node):
    """Return every NameNode in the subtree rooted at *node*, in tree order.

    A falsy node (e.g. None) yields an empty list; a NameNode yields itself
    without descending into its children.
    """
    if not node:
        return []
    if node.__class__.__name__ == 'NameNode':
        return [node]
    return [name for child in node.children for name in getNameNodes(child)]
# Input: Block Node
# Output: A list of local var dcl to be pushed onto the stack
def getVarDclNodes(node):
    """Return the local variable declarations of a block, in source order.

    VarDclNode statements directly in *node* are collected, and nested
    BlockNode statements are searched recursively.

    NOTE(review): declarations nested inside other statement kinds (e.g. the
    bodies of if/while/for statements) are not visited here — confirm whether
    those locals receive a stack offset elsewhere.
    """
    result = []
    for s in node.statements:
        kind = s.__class__.__name__
        if kind == 'VarDclNode':
            result.append(s)
        elif kind == 'BlockNode':
            # recurse into the nested block itself, not the enclosing one
            result += getVarDclNodes(s)
    return result