Skip to content
Snippets Groups Projects
Commit 014420b3 authored by Xun Yang's avatar Xun Yang
Browse files

some type checking stuff

parent c604e257
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ class ASTNode():
def __init__(self, parseTree):
self.parseTree = parseTree
self.children = []
self.myType = "" # either empty string or a TypeStruct
# Do certains actions on every node of the AST tree
# call the same method in each class and its children recursively
......@@ -40,7 +41,7 @@ class ASTNode():
else:
c.recurseBuildEnv(result)
if c.__class__.__name__ == 'VarDclNode':
preVarDcl = c
preVarDcl = c
def buildEnv(self, parentEnv):
......@@ -53,6 +54,11 @@ class ASTNode():
def checkHierarchy(self):
pass
def checkType(self):
for c in self.children:
if c and hasattr(c, 'checkType'):
c.checkType()
def printNodePretty(self, prefix=0):
pp = pprint.PrettyPrinter(indent=prefix)
pp.pprint(self.__class__.__name__)
......
......@@ -42,4 +42,7 @@ def buildEnvAndLink(ASTs):
for t in ASTs:
t[1].recurseAction("checkHierarchy")
for t in ASTs:
t[1].checkType()
#######################################################
from AST import ASTNode, getParseTreeNodes
from Environment import Env
from UnitNodes import LiteralNode
from TheTypeNode import TypeNode
from TheTypeNode import TypeNode, TypeStruct
# file containing smaller (lower level nodes) in the AST
# nodes in this file:
......@@ -48,7 +48,7 @@ def makeNodeFromAllPrimary(parseTree):
if parseTree.children[0].children[0].name == 'arrayAccess':
return ArrayAccessNode(parseTree.children[0].children[0])
parseTree = parseTree.children[0].children[0]
if parseTree.name == 'primary':
if parseTree.children[0].name == 'arrayAccess':
return ArrayAccessNode(parseTree.children[0])
......@@ -207,6 +207,16 @@ class ExprNode(ASTNode):
self.children.append(self.left)
self.children.append(self.right)
def checkType(self):
super().checkType() # check children's type first to populate their myType field
if self.op == '==' or self.op == '!=':
if (self.left.myType == self.right.myType) or (self.left.myType.isNum() and self.right.myType.isNum()):
self.myType = TypeStruct("boolean")
else:
raise Exception('ERROR: Incompatible types for comparison.')
# TODO: type check other types of expr
###################################################################################
# fieldAccess primary PERIOD ID
......
......@@ -61,7 +61,7 @@ class MethodNode(ASTNode):
for n in nameNodes:
paramNode = ParamNode(n)
self.params.append(paramNode)
self.paramTypes += paramNode.paramType.name
self.paramTypes += paramNode.paramType.myType.name
nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
for n in nameNodes:
......
from AST import ASTNode, getParseTreeNodes
##################################################################################
# type: primitiveType, ArrayType, RefType
# TypeNode: an AST node represents a type
# TypeStruct: a struct holding type information for type checking
class TypeNode(ASTNode):
# always list all fields in the init method to show the class structure
def __init__(self, parseTree):
self.parseTree = parseTree
self.name = ''
self.isArray = False
self.isPrimitive = False
self.env = None
self.children = []
self.myType = None # pointer pointing to the type
self.myType = "" # empty string or typeStruct
if parseTree == 'VOID':
self.name = 'void'
self.isPrimitive = True
self.myType = TypeStruct('void')
else:
nameNodes = getParseTreeNodes(['BOOLEAN', 'BYTE', 'CHAR', 'INT', 'SHORT'], parseTree)
if nameNodes:
self.isPrimitive = True
self.name = nameNodes[0].lex
self.myType = TypeStruct(nameNodes[0].lex)
else:
self.name = getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex
self.myType = TypeStruct(getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex)
nameNodes = getParseTreeNodes(['LSQRBRACK'], parseTree)
if nameNodes:
self.isArray = True
self.myType.isArray = True
def __eq__(self, other):
return self.name == other.name
return self.myType == other.myType
def linkType(self):
self.myType.link(self.env)
class TypeStruct():
def __init__(self, name):
self.isArray = False
self.isPrimitive = False
self.typePointer = None
self.name = name
if name in ['boolean', 'byte', 'char', 'int', 'short', 'void']:
self.isPrimitive = True
def link(self, env):
if not self.isPrimitive:
self.myType = self.env.getNode(self.name, 'type')
self.name = self.myType.canonName # Use canonName instead of simple name for comparison
else:
self.myType = self.name
self.typePointer = env.getNode(self.name, 'type')
self.name = self.typePointer.canonName # Use canonName instead of simple name for comparison
def __eq__(self, other):
return self.name == other.name
def isNum(self):
return self.name in ['int', 'short', 'char', 'byte']
# if self is assignable to input typeNode: left := self
def assignable(self, left):
if self == left \
or (self.name in ['short', 'char', 'byte'] and left.name == 'int') \
or (self.name == 'byte' and left.name == 'short') \
or (not left.isPrimitive and self.name == 'null'):
return True
return False
# if self is assignable to input typeNode: self := right
# right is either a TypeNode or a LiteralNode
def assignable(self, right):
if self.isArray == right.isArray:
if self == right \
or (right.name in ['short', 'char', 'byte'] and self.name == 'int') \
or (right.name == 'byte' and self.name == 'short') \
or (not self.isPrimitive and right.name == 'null'):
return True
# check if self is super of right
elif ((not self.isPrimitive) and (not right.isPrimitive)) \
and (self.name in getSupers(right.typePointer)):
return True
return False
return False
# is java.Object added to super class of everything/
# helper: get list of all super class/interface of a ClassInterNode
def getSupers(classType):
result = []
if not classType.super:
return result
for s in classType.super:
result.append(s.canonName)
result.extend(getSupers(s))
return result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment