diff --git a/AST.py b/AST.py index 25b46ee245e6c5f4a0f381d29682d55767070f87..75aedd1839012adf5a2fea88f251d8ab222d9a84 100644 --- a/AST.py +++ b/AST.py @@ -12,6 +12,7 @@ class ASTNode(): def __init__(self, parseTree): self.parseTree = parseTree self.children = [] + self.myType = "" # either empty string or a TypeStruct # Do certains actions on every node of the AST tree # call the same method in each class and its children recursively @@ -40,7 +41,7 @@ class ASTNode(): else: c.recurseBuildEnv(result) if c.__class__.__name__ == 'VarDclNode': - preVarDcl = c + preVarDcl = c def buildEnv(self, parentEnv): self.env = parentEnv @@ -55,6 +56,12 @@ class ASTNode(): def disambigName(self): pass + def checkType(self): + # self is type correct if all its children are type correct (no exception raised) + for c in self.children: + if c and hasattr(c, 'checkType'): + c.checkType() + def printNodePretty(self, prefix=0): pp = pprint.PrettyPrinter(indent=prefix) pp.pprint(self.__class__.__name__) diff --git a/AstBuilding.py b/AstBuilding.py index 74574e1df32dba10a867b4c62ea493da7e96fd71..a97aaa338fb79e175b5ec133d9542a31e2f24f41 100644 --- a/AstBuilding.py +++ b/AstBuilding.py @@ -42,6 +42,9 @@ def buildEnvAndLink(ASTs): for t in ASTs: t[1].recurseAction("checkHierarchy") + for t in ASTs: + t[1].checkType() + ####################################################### def disamiguateAndTypeChecking(ASTs): diff --git a/ExprPrimaryNodes.py b/ExprPrimaryNodes.py index 7e2aefba8f53d0f4bd916f8a4ec10e2d43b2c6a2..df487a35d823f1a16463f4118e5347638556e4c6 100644 --- a/ExprPrimaryNodes.py +++ b/ExprPrimaryNodes.py @@ -1,7 +1,7 @@ from AST import ASTNode, getParseTreeNodes from Environment import Env from UnitNodes import LiteralNode -from TheTypeNode import TypeNode +from TheTypeNode import TypeNode, TypeStruct from NameNode import NameNode # file containing smaller (lower level nodes) in the AST @@ -49,7 +49,7 @@ def makeNodeFromAllPrimary(parseTree, typeName): if parseTree.children[0].children[0].name == 'arrayAccess': return ArrayAccessNode(parseTree.children[0].children[0], typeName) parseTree = parseTree.children[0].children[0] - + if parseTree.name == 'primary': if parseTree.children[0].name == 'arrayAccess': return ArrayAccessNode(parseTree.children[0], typeName) @@ -220,11 +220,17 @@ class ClassCreateNode(ASTNode): # always list all fields in the init method to show the class structure def __init__(self, parseTree, typeName): self.parseTree = parseTree.children[0] # input is classInstanceCreate unqualCreate - self.className = TypeNode(parseTree.children[1], typeName) # the class you're creating, already type linked + self.className = TypeNode(parseTree.children[0].children[1], typeName) self.args = ArgsNode(self.parseTree.children[3], typeName) self.env = None - self.children = [self.args] - self.typeName = typeName # the type (class/interface) this node belongs under + self.children = [self.args, self.className] + self.typeName = typeName + + def checkType(self): + # check class is not abstract + if 'abstract' in self.className.myType.typePointer.mods: + raise Exception('ERROR: Cannot create an instance of abstract class {}.'.format(self.className.myType.name)) + # TODO: more type checking @@ -242,19 +248,75 @@ class ExprNode(ASTNode): self.typeName = typeName # the type (class/interface) this node belongs under if parseTree.name == 'unaryNotPlusMinus' or parseTree.name == 'unaryExpr': - self.op = parseTree.children[0].name + self.op = parseTree.children[0].lex self.right = makeNodeFromExpr(parseTree.children[1], typeName) else: self.left = makeNodeFromExpr(parseTree.children[0], typeName) - self.op = parseTree.children[1].name + self.op = parseTree.children[1].lex self.right = makeNodeFromExpr(parseTree.children[2], typeName) self.children.append(self.left) self.children.append(self.right) - def disambigName(self): - helperDisambigName(self.left) - helperDisambigName(self.right) + def disambigName(self): + helperDisambigName(self.left) + helperDisambigName(self.right) + + # use wrong name to stop method from being called until we finish other implemetation + # def checkType(self): + def checkType1(self): + # steps of type checking: + # check children's types (children's myType field will be populated with a typeStruct) + # check using the rule for current node + # make a TypeStruct node and populate myType field for self + + super().checkType() # check children's type first to populate their myType field + print(self.left) + # Unary operations: + if not self.left: + if self.op == '-' and self.right.myType.isNum(): + self.myType = TypeStruct("int") + return + elif self.op == '!' and self.right.myType.name == 'boolean': + self.myType = self.myType = TypeStruct("boolean") + return + + # Numeric types + elif self.left.myType.isNum() and self.right.myType.isNum(): + # Comparisons: + if self.op in ['==', '!=', '<=', '>=', '>', '<']: + self.myType = TypeStruct("boolean") + return + # numeric operations: + elif self.op in ['+', '-', '*', '/']: + self.myType = TypeStruct("int") + return + # Boolean operations: + elif self.left.myType.name == 'boolean' and self.right.myType.name == 'boolean': + if self.op in ['&&', '&', '|', '||']: + self.myType = TypeStruct("boolean") + return + # Other Comparisons: + elif self.left.myType.assignable(self.right.myType) or self.right.myType.assignable(self.left.myType): + if self.op == '==' or self.op == '!=': + self.myType = TypeStruct("boolean") + return + + # String concat: + elif (self.left.myType.name =='java.lang.String' and self.right.myType.name not in ['null', 'void']) \ + or (self.right.myType.name =='java.lang.String' and self.left.myType.name not in ['null', 'void']): + self.myType = TypeStruct('java.lang.String') + self.myType.link(self.env) + return + + elif self.op == 'instanceof': + # assume it's correct for now, wait for runtime check + self.myType = TypeStruct("boolean") + return + + raise Exception("ERROR: Incompatible types. Left of {} type can't be used with right of {} type on operation {}".format(self.op, self.left.myType.name, self.right.myType.name)) + + ################################################################################### diff --git a/MemberNodes.py b/MemberNodes.py index 92ec06acb6145905330510e73e64645e57e1b376..b44ff9c996b1261fca2b829cec759a470ddfdb04 100644 --- a/MemberNodes.py +++ b/MemberNodes.py @@ -28,14 +28,14 @@ class FieldNode(ASTNode): for m in node.children: self.mods.append(m.lex) - elif node.name == 'type': - self.fieldType = TypeNode(node, self.typeName) - elif node.name == 'variableDcl': self.variableDcl = VarDclNode(node, self.typeName) self.children.append(self.variableDcl) + def __eq__(self, other): + return self.name == other.name + ########################################################### # method @@ -63,7 +63,7 @@ class MethodNode(ASTNode): for n in nameNodes: paramNode = ParamNode(n, self.typeName) self.params.append(paramNode) - self.paramTypes += paramNode.paramType.name + self.paramTypes += paramNode.paramType.myType.name nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params']) for n in nameNodes: @@ -83,6 +83,7 @@ class MethodNode(ASTNode): self.body = BlockNode(n, typeName) if self.body: self.children.append(self.body) + self.children.append(self.methodType) def __eq__(self, other): if self.name == other.name and len(self.params) == len(other.params): @@ -92,8 +93,6 @@ class MethodNode(ASTNode): raise Exception('HERE {}, {}'.format(self.params[i].paramType.name, other.params[i].paramType.name)) return False return True - if self.name == 'addAll': - raise Exception('THERE') return False def buildEnv(self, parentEnv): diff --git a/TheTypeNode.py b/TheTypeNode.py index 57b8a22aaea77975e5cfc2cff79d230ed59b0410..9795914c21094db25252bebc283f364be666b35f 100644 --- a/TheTypeNode.py +++ b/TheTypeNode.py @@ -1,54 +1,83 @@ from AST import ASTNode, getParseTreeNodes ################################################################################## -# type: primitiveType, ArrayType, RefType +# TypeNode: an AST node represents a type +# TypeStruct: a struct holding type information for type checking + +# TypeNode represents a parse tree unit that contains a type, +# TypeStruct is not a unit on parseTree, it is just a struct living in different AST nodes to keep track of their type + class TypeNode(ASTNode): # always list all fields in the init method to show the class structure def __init__(self, parseTree, typeName): self.parseTree = parseTree self.name = '' - self.isArray = False - self.isPrimitive = False self.env = None self.children = [] - self.myType = None # pointer pointing to the type - self.typeName = typeName - + self.myType = "" # empty string or typeStruct + if parseTree == 'VOID': - self.name = 'void' - self.isPrimitive = True + self.myType = TypeStruct('void') else: nameNodes = getParseTreeNodes(['BOOLEAN', 'BYTE', 'CHAR', 'INT', 'SHORT'], parseTree) if nameNodes: - self.isPrimitive = True - self.name = nameNodes[0].lex + self.myType = TypeStruct(nameNodes[0].lex) else: - self.name = getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex + self.myType = TypeStruct(getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex) nameNodes = getParseTreeNodes(['LSQRBRACK'], parseTree) if nameNodes: - self.isArray = True + self.myType.isArray = True def __eq__(self, other): - return self.name == other.name + return self.myType == other.myType def linkType(self): + self.myType.link(self.env) + + +class TypeStruct(): + def __init__(self, name): + self.isArray = False + self.isPrimitive = False + self.typePointer = None + self.name = name + if name in ['boolean', 'byte', 'char', 'int', 'short', 'void']: + self.isPrimitive = True + + def link(self, env): if not self.isPrimitive: - self.myType = self.env.getNode(self.name, 'type') - self.name = self.myType.canonName # Use canonName instead of simple name for comparison - else: - self.myType = self.name + self.typePointer = env.getNode(self.name, 'type') + self.name = self.typePointer.canonName # Use canonName instead of simple name for comparison + + def __eq__(self, other): + return self.name == other.name def isNum(self): return self.name in ['int', 'short', 'char', 'byte'] - # if self is assignable to input typeNode: left := self - def assignable(self, left): - if self == left \ - or (self.name in ['short', 'char', 'byte'] and left.name == 'int') \ - or (self.name == 'byte' and left.name == 'short') \ - or (not left.isPrimitive and self.name == 'null'): - return True - return False + # if self is assignable to input typeNode: self := right + # right is either a TypeNode or a LiteralNode + def assignable(self, right): + if self.isArray == right.isArray: + if self == right \ + or (right.name in ['short', 'char', 'byte'] and self.name == 'int') \ + or (right.name == 'byte' and self.name == 'short') \ + or (not self.isPrimitive and right.name == 'null'): + return True + # check if self is super of right + elif ((not self.isPrimitive) and (not right.isPrimitive)) \ + and (self.name in getSupers(right.typePointer)): + return True + return False + return False - # is java.Object added to super class of everything/ +# helper: get list of all super class/interface of a ClassInterNode +def getSupers(classType): + result = [] + if not classType.super: + return result + for s in classType.super: + result.append(s.canonName) + result.extend(getSupers(s)) + return result diff --git a/TypeNodes.py b/TypeNodes.py index e8705b6c4759e6431cc1688991056d8e61da45dd..920449d447b60fd57a0840b92e9effea79db2147 100644 --- a/TypeNodes.py +++ b/TypeNodes.py @@ -13,10 +13,10 @@ class ClassInterNode(ASTNode): self.superInter = [] # class/Interface's name, then stores a pointer to the class after type linking self.env = None self.children = [] - self.canonName = "" + self.canonName = '' # sets - self.contains = [] + self.inherits = [] self.super = [] def checkHierarchy(self): @@ -29,50 +29,118 @@ class ClassInterNode(ASTNode): unique.append(inter.name) # 7. A class or interface must not declare two methods with the same signature (name and parameter types). + # 9. A class/interface must not contain two methods with the same signature but different return types unique = [] for method in self.methods: key = (method.name, method.paramTypes) if key in unique: raise Exception("ERROR: Class/Interface '{}' declares 2 methods with the same signature '{}'".format(self.name, key[0])) + + # quick fix for errors with java.lang.Object getClass function + if self.canonName != 'java.lang.Object' and key[0] == 'getClass' and not key[1]: + raise Exception("ERROR: Method 'getClass' in class/interface '{}' replaces java.lang.Object's final method".format(self.name, key[0])) + unique.append(key) - # 6. The hierarchy must be acyclic - # 9. A class or interface must not contain (declare or inherit) two methods with the same signature but different return types - # 11. A nonstatic method must not replace a static method - # 13. A protected method must not replace a public method - self.contains = self.getContains([]) + contains = self.getContains([]) - def getContains(self, hierarchy): - # check if not acyclic + # 10. A class that contains any abstract methods must be abstract. + if isinstance(self, ClassNode): + for con in contains: + if isinstance(con, MethodNode): + if 'abstract' in con.mods and (not('abstract' in self.mods)): + raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method".format(self.name)) + if (not con.body) and (not ('native' in con.mods)) and (not ('abstract' in self.mods)) and (not (con in self.constructors)): + raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method {}".format(self.name, con.name)) + # add inherited methods/fields to environment + for i in self.inherits: + self.env.addtoEnv(i) + + # hierarchy: string[] + def getContains(self, hierarchy): + # ---- ACYCLIC CHECK ---- # + # 6. The hierarchy must be acyclic if self.canonName in hierarchy: raise Exception("ERROR: The hierarchy is not acyclic '{}', saw '{}'".format(hierarchy, self.canonName)) - contains = [] + # ---- GET SUPER CONTAINS ---- # + superContains = [] + superInterContains = [] + superClassContains = [] + # parent interface methods for inter in self.superInter: - superContains = inter.getContains(hierarchy + [self.canonName]) - for con in superContains: - conOverwritten = False - for method in self.methods: - if (method == con): - # cannot have same signiture but different return types - if (method.methodType != con.methodType): - raise Exception("ERROR: Class '{}' contains 2 methods '{}' with the same signature but different return types".format(self.name, method.name)) - # protected must not replace public - if 'protected' in method.mods and 'public' in con.mods: - raise Exception("ERROR: Protected method '{}' in class '{}' replaces public method '{}' in class {}".format(method.name, self.name, con.name, inter.name)) - # 14. A method must not replace a final method - if 'final' in con.mods: - raise Exception("ERROR: Final method '{}' in class '{}' can't be overrided by method '{}' in class {}".format(method.name, self.name, con.name, inter.name)) - conOverwritten = True + superInterContains.extend(inter.getContains(hierarchy + [self.canonName])) + + # parent class methods + if hasattr(self, 'superClass') and self.superClass: + superClassContains.extend(self.superClass.getContains(hierarchy + [self.canonName])) + + # replace superInter methods that superClass implements + # this adds methods from superInterContains to superContains + # example: Tests/A2/J1_supermethod_override + for sic in superInterContains: + sicOverwritten = False + for scc in superClassContains: + if type(sic) == type(scc) and sic == scc: + safeReplace(sic, scc, self.name) + sicOverwritten = True break - if not conOverwritten: - contains.append(con) - + if not sicOverwritten: + superContains.append(sic) + + superContains.extend(superClassContains) + + elif not self.superInter and self.canonName != 'java.lang.Object': + # an interface without any super interfaces implicitly declares an abstract version of every public method in java.lang.Object + # as well; every class inherits each method from java.lang.Object + objectInterface = self.env.getNode('java.lang.Object', 'type') + objectContains = objectInterface.getContains(hierarchy + [self.canonName]) + for oc in objectContains: + if isinstance(self, ClassNode) or 'public' in oc.mods: + superContains.append(oc) + + elif superInterContains: # if no superClass and we do have superInterContains + superContains.extend(superInterContains) + + # ---- SUPER CONTAINS AGAINST SELF DECLARES ---- # + inherits = [] + contains = [] contains.extend(self.methods) + if hasattr(self, 'fields'): + contains.extend(self.fields) + + for sc in superContains: + scOverwritten = False + for c in contains: + if type(sc) == type(c) and sc == c: + safeReplace(sc, c, self.name) + scOverwritten = True + break + if not scOverwritten: + contains.append(sc) + inherits.append(sc) + + if not self.inherits: + self.inherits.extend(inherits) + return contains + def printSets(self): + print("---- Sets for Class/Interface {}".format(self.name)) + + print("> self.super") + for c in self.super: + print(c.name) + + print("> self.inherits") + for i in self.inherits: + if isinstance(i, MethodNode): + print(i.name + "(" + i.paramTypes + ")") + if isinstance(i, FieldNode): + print(i.name) + # class class ClassNode(ClassInterNode): # always list all fields in the init method to show the class structure @@ -147,8 +215,8 @@ class ClassNode(ClassInterNode): newSuperInter = self.env.getNode(inter, 'type') self.superInter[index] = newSuperInter self.super.append(newSuperInter) - if not self.super and self.canonName != "java.lang.Object": - objectNode = self.env.getNode("java.lang.Object", 'type') + if not self.super and self.canonName != 'java.lang.Object': + objectNode = self.env.getNode('java.lang.Object', 'type') self.super.append(objectNode) def checkHierarchy(self): @@ -177,51 +245,9 @@ class ClassNode(ClassInterNode): raise Exception("ERROR: Class '{}' declares 2 constructors with the same parameter types".format(self.name)) unique.append(key) - # centralized point for overlapping class & interface logic. Also sets self.contains + # overlapping class/interface logic super().checkHierarchy() - # 10. A class that contains (declares or inherits) any abstract methods must be abstract. - for con in self.contains: - if 'abstract' in con.mods and (not('abstract' in self.mods)): - raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method".format(self.name)) - if (not con.body) and (not ('native' in con.mods)) and (not ('abstract' in self.mods)) and (not (con in self.constructors)): - raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method {}".format(self.name, con.name)) - - # hierarchy: string[] - def getContains(self, hierarchy): - # centralized logic - contains = super().getContains(hierarchy) - - # get contains from extends class - if self.superClass: - addToContains = [] - superContains = self.superClass.getContains(hierarchy + [self.canonName]) - for con in superContains: - conOverwritten = False - for index, method in enumerate(contains): - if (method == con): - # cannot have same signiture but different return types - if (method.methodType != con.methodType): - raise Exception("ERROR: Class '{}' contains 2 methods '{}' with the same signature but different return types".format(self.name, method.name)) - # must not replace final - if 'final' in con.mods: - raise Exception("ERROR: Method '{}' in class '{}' replaces final method '{}' in class '{}'".format(method.name, self.name, con.name, self.superClass.name)) - # nonstatic must not replace static - if 'static' not in method.mods and 'static' in con.mods: - raise Exception("ERROR: Non-static method '{}' in class '{}' replaces static method '{}' in class '{}'".format(method.name, self.name, con.name, self.superClass.name)) - # protected must not replace public - if 'protected' in method.mods and 'public' in con.mods: - raise Exception("ERROR: Protected method '{}' in class '{}' replaces public method '{}' in class {}".format(method.name, self.name, con.name, self.superClass.name)) - conOverwritten = True - if not method.body: - contains[index] = con - conOverwritten = False - break - if not conOverwritten: - addToContains.append(con) - contains += addToContains - return contains - def getConstructor(self, argTypes): for c in self.constructors: if c.paramTypes == argTypes: @@ -287,33 +313,32 @@ class InterNode(ClassInterNode): raise Exception("ERROR: Interface '{}' extends duplicate interfaces '{}'".format(self.name, inter.name)) unique.append(inter.name) - # centralized point for overlapping class & interface logic. Also sets self.contains + # centralized point for overlapping class & interface logic. super().checkHierarchy() - # hierarchy: string[] - def getContains(self, hierarchy): - # centralized logic - contains = super().getContains(hierarchy) - - # an interface without any super interfaces implicitly declares an abstract version of every public method in java.lang.Object - if not self.superInter: - addToContains = [] - objectInterface = self.env.getNode('java.lang.Object', 'type') - superContains = objectInterface.getContains(hierarchy + [self.canonName]) - for con in superContains: - conOverwritten = False - if 'public' in con.mods: - for method in self.methods: - if (method == con): - # cannot have same signiture but different return types - if (method.methodType != con.methodType): - raise Exception("ERROR: Class '{}' contains 2 methods '{}' with the same signature but different return types".format(self.name, method.name)) - # protected must not replace public - if 'protected' in method.mods and 'public' in con.mods: - raise Exception("ERROR: Protected method '{}' in class '{}' replaces public method '{}' in class {}".format(method.name, self.name, con.name, inter.name)) - conOverwritten = True - break - if not conOverwritten: - addToContains.append(con) - # contains += addToContains # this is causing very VERY WEIRD ERRORS I AM VERY FRUSTRATED, DO NOT UNCOMMENT THIS IF YOU WISH TO HAVE A GOOD TIME, UNCOMMENT AT YOUR PERIL - return contains +# helper - replace method check +# cur/new: MethodNode or FieldNode +def safeReplace(cur, new, className): + # getting here signifies that cur and new have the same signature (and are the same type) + + methodOrField = 'method' + if isinstance(cur, FieldNode): + methodOrField = 'field' + + # 11. A nonstatic method must not replace a static method + if 'static' in cur.mods and 'static' not in new.mods: + raise Exception("ERROR: Non-static {0} '{1}' in class '{2}' replaces static {0}".format(methodOrField, new.name, className)) + + # 9. A class/interface must not contain two methods with the same signature but different return types + # 12. A method must not replace a method with a different return type + if isinstance(cur, MethodNode) and cur.methodType != new.methodType: + raise Exception("ERROR: Method '{}' in class '{}' replaces method with a different return type".format(className, cur.name)) + + # 13. A protected method must not replace a public method + if 'public' in cur.mods and 'protected' in new.mods: + raise Exception("ERROR: Protected {0} '{1}' in class '{2}' replaces public {0}".format(methodOrField, new.name, className)) + + # 14. A method must not replace a final method + # quick fix for final method getClass from java.lang.Object + if 'final' in cur.mods and cur.name != 'getClass': + raise Exception("ERROR: {} '{}' in class '{}' replaces final {}".format(methodOrField.capitalize(), cur.name, className, methodOrField)) diff --git a/UnitNodes.py b/UnitNodes.py index aa7f1375e718f0ba7df22facb37b7efe595c7d7c..2a8165234b857495c388bbbf53aba0ef4d1e0e2f 100644 --- a/UnitNodes.py +++ b/UnitNodes.py @@ -1,6 +1,6 @@ from AST import ASTNode, getParseTreeNodes from Environment import Env -from TheTypeNode import TypeNode +from TheTypeNode import TypeNode, TypeStruct # LiteralNode # ParamNode @@ -10,9 +10,9 @@ from TheTypeNode import TypeNode # literals class LiteralNode(ASTNode): toLiType = dict({ - 'LITERALBOOL': 'bool', + 'LITERALBOOL': 'boolean', 'LITERALCHAR': 'char', - 'LITERALSTRING': 'String', + 'LITERALSTRING': 'java.lang.String', 'NULL': 'null', 'NUM': 'int', 'ZERO': 'int' @@ -20,15 +20,16 @@ class LiteralNode(ASTNode): # always list all fields in the init method to show the class structure def __init__(self, parseTree, typeName): self.parseTree = parseTree - self.liType = LiteralNode.toLiType.get(parseTree.children[0].name) # type of the literal + self.name = LiteralNode.toLiType.get(parseTree.children[0].name) # type of the literal self.value = parseTree.children[0].lex # the value + self.myType = TypeStruct(self.name) self.env = None self.children = [] self.typeName = typeName - if self.liType == 'int': + if self.name == 'int': self.value = int(self.value) - if self.liType == 'LITERALBOOL': + if self.name == 'LITERALBOOL': if self.value == 'false': self.value = False else: