From 014420b3c0dfc75b844c623840424931777700bc Mon Sep 17 00:00:00 2001 From: Xun Yang <x299yang@uwaterloo.ca> Date: Tue, 3 Mar 2020 19:45:16 -0500 Subject: [PATCH] some type checking stuff --- AST.py | 8 ++++- AstBuilding.py | 3 ++ ExprPrimaryNodes.py | 14 +++++++-- MemberNodes.py | 2 +- TheTypeNode.py | 75 ++++++++++++++++++++++++++++++--------------- 5 files changed, 74 insertions(+), 28 deletions(-) diff --git a/AST.py b/AST.py index 4a219c7..f8ee30d 100644 --- a/AST.py +++ b/AST.py @@ -12,6 +12,7 @@ class ASTNode(): def __init__(self, parseTree): self.parseTree = parseTree self.children = [] + self.myType = "" # either empty string or a TypeStruct # Do certains actions on every node of the AST tree # call the same method in each class and its children recursively @@ -40,7 +41,7 @@ class ASTNode(): else: c.recurseBuildEnv(result) if c.__class__.__name__ == 'VarDclNode': - preVarDcl = c + preVarDcl = c def buildEnv(self, parentEnv): @@ -53,6 +54,11 @@ class ASTNode(): def checkHierarchy(self): pass + def checkType(self): + for c in self.children: + if c and hasattr(c, 'checkType'): + c.checkType() + def printNodePretty(self, prefix=0): pp = pprint.PrettyPrinter(indent=prefix) pp.pprint(self.__class__.__name__) diff --git a/AstBuilding.py b/AstBuilding.py index 27803a9..fc92c15 100644 --- a/AstBuilding.py +++ b/AstBuilding.py @@ -42,4 +42,7 @@ def buildEnvAndLink(ASTs): for t in ASTs: t[1].recurseAction("checkHierarchy") + for t in ASTs: + t[1].checkType() + ####################################################### diff --git a/ExprPrimaryNodes.py b/ExprPrimaryNodes.py index 6ac4f83..9f62d59 100644 --- a/ExprPrimaryNodes.py +++ b/ExprPrimaryNodes.py @@ -1,7 +1,7 @@ from AST import ASTNode, getParseTreeNodes from Environment import Env from UnitNodes import LiteralNode -from TheTypeNode import TypeNode +from TheTypeNode import TypeNode, TypeStruct # file containing smaller (lower level nodes) in the AST # nodes in this file: @@ -48,7 +48,7 @@ def makeNodeFromAllPrimary(parseTree): if parseTree.children[0].children[0].name == 'arrayAccess': return ArrayAccessNode(parseTree.children[0].children[0]) parseTree = parseTree.children[0].children[0] - + if parseTree.name == 'primary': if parseTree.children[0].name == 'arrayAccess': return ArrayAccessNode(parseTree.children[0]) @@ -207,6 +207,16 @@ class ExprNode(ASTNode): self.children.append(self.left) self.children.append(self.right) + def checkType(self): + super().checkType() # check children's type first to populate their myType field + if self.op == '==' or self.op == '!=': + if (self.left.myType == self.right.myType) or (self.left.myType.isNum() and self.right.myType.isNum()): + self.myType = TypeStruct("boolean") + else: + raise Exception('ERROR: Incompatible types for comparison.') + # TODO: type check other types of expr + + ################################################################################### # fieldAccess primary PERIOD ID diff --git a/MemberNodes.py b/MemberNodes.py index 0b07d9e..23c65a8 100644 --- a/MemberNodes.py +++ b/MemberNodes.py @@ -61,7 +61,7 @@ class MethodNode(ASTNode): for n in nameNodes: paramNode = ParamNode(n) self.params.append(paramNode) - self.paramTypes += paramNode.paramType.name + self.paramTypes += paramNode.paramType.myType.name nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params']) for n in nameNodes: diff --git a/TheTypeNode.py b/TheTypeNode.py index 0752c34..0ff37f1 100644 --- a/TheTypeNode.py +++ b/TheTypeNode.py @@ -1,53 +1,80 @@ from AST import ASTNode, getParseTreeNodes ################################################################################## -# type: primitiveType, ArrayType, RefType +# TypeNode: an AST node represents a type +# TypeStruct: a struct holding type information for type checking + class TypeNode(ASTNode): # always list all fields in the init method to show the class structure def __init__(self, parseTree): self.parseTree = parseTree self.name = '' - self.isArray = False - self.isPrimitive = False self.env = None self.children = [] - self.myType = None # pointer pointing to the type + self.myType = "" # empty string or typeStruct if parseTree == 'VOID': - self.name = 'void' - self.isPrimitive = True + self.myType = TypeStruct('void') else: nameNodes = getParseTreeNodes(['BOOLEAN', 'BYTE', 'CHAR', 'INT', 'SHORT'], parseTree) if nameNodes: - self.isPrimitive = True - self.name = nameNodes[0].lex + self.myType = TypeStruct(nameNodes[0].lex) else: - self.name = getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex + self.myType = TypeStruct(getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex) nameNodes = getParseTreeNodes(['LSQRBRACK'], parseTree) if nameNodes: - self.isArray = True + self.myType.isArray = True def __eq__(self, other): - return self.name == other.name + return self.myType == other.myType def linkType(self): + self.myType.link(self.env) + + +class TypeStruct(): + def __init__(self, name): + self.isArray = False + self.isPrimitive = False + self.typePointer = None + self.name = name + if name in ['boolean', 'byte', 'char', 'int', 'short', 'void']: + self.isPrimitive = True + + def link(self, env): if not self.isPrimitive: - self.myType = self.env.getNode(self.name, 'type') - self.name = self.myType.canonName # Use canonName instead of simple name for comparison - else: - self.myType = self.name + self.typePointer = env.getNode(self.name, 'type') + self.name = self.typePointer.canonName # Use canonName instead of simple name for comparison + + def __eq__(self, other): + return self.name == other.name def isNum(self): return self.name in ['int', 'short', 'char', 'byte'] - # if self is assignable to input typeNode: left := self - def assignable(self, left): - if self == left \ - or (self.name in ['short', 'char', 'byte'] and left.name == 'int') \ - or (self.name == 'byte' and left.name == 'short') \ - or (not left.isPrimitive and self.name == 'null'): - return True - return False + # if self is assignable to input typeNode: self := right + # right is either a TypeNode or a LiteralNode + def assignable(self, right): + if self.isArray == right.isArray: + if self == right \ + or (right.name in ['short', 'char', 'byte'] and self.name == 'int') \ + or (right.name == 'byte' and self.name == 'short') \ + or (not self.isPrimitive and right.name == 'null'): + return True + # check if self is super of right + elif ((not self.isPrimitive) and (not right.isPrimitive)) \ + and (self.name in getSupers(right.typePointer)): + return True + return False + return False - # is java.Object added to super class of everything/ +# helper: get list of all super class/interface of a ClassInterNode +def getSupers(classType): + result = [] + if not classType.super: + return result + for s in classType.super: + result.append(s.canonName) + result.extend(getSupers(s)) + return result -- GitLab