Skip to content
Snippets Groups Projects
TypeNodes.py 11.81 KiB
from AST import ASTNode, getParseTreeNodes
from MemberNodes import FieldNode, MethodNode
from Environment import Env

# Type declaration nodes: classes and interfaces

class ClassInterNode(ASTNode):
    """Hierarchy checks shared by class and interface declaration nodes.

    Subclasses set ``name``, ``methods`` and ``superInter``. ``superInter``
    starts as a list of type-name strings and is replaced by the linked type
    nodes in the subclasses' ``linkType``. The checks below read
    ``inter.name`` and call ``inter.getContains``, so they must only run
    after linking.
    """

    def checkHierarchy(self):
        """Run the hierarchy checks common to classes and interfaces.

        Returns:
            The methods this type contains (declared or inherited), as built
            by ``getContains``.

        Raises:
            Exception: if a super-interface is listed twice, if two declared
                methods share a signature, or if ``getContains`` rejects the
                hierarchy.
        """
        # 3. An interface must not be repeated in an implements clause
        if self.superInter:
            seenInters = []
            for inter in self.superInter:
                if inter.name in seenInters:
                    raise Exception("ERROR: Class/Interface '{}' implements duplicate interfaces '{}'".format(self.name, inter.name))
                seenInters.append(inter.name)

        # 7. A class or interface must not declare two methods with the same
        # signature (name and parameter types). A list is used rather than a
        # set because paramTypes may be an unhashable list (TODO confirm in
        # MethodNode).
        seenSigs = []
        for method in self.methods:
            key = (method.name, method.paramTypes)
            if key in seenSigs:
                raise Exception("ERROR: Class/Interface '{}' declares 2 methods with the same signature '{}'".format(self.name, key[0]))
            seenSigs.append(key)

        # Checked while walking the hierarchy in getContains:
        # 6. The hierarchy must be acyclic
        # 9. A class or interface must not contain (declare or inherit) two methods with the same signature but different return types
        # 11. A nonstatic method must not replace a static method
        # 13. A protected method must not replace a public method
        # 14. A method must not replace a final method
        return self.getContains([])

    def getContains(self, hierarchy):
        """Return the declared methods plus the super-interface methods not replaced.

        A method inherited from a super-interface is added to the result
        unless a method with the same signature is already present. That
        method is either declared here or inherited from an earlier
        super-interface.

        Args:
            hierarchy: Names of the types on the current inheritance path,
                used for cycle detection. ``self.name`` is appended to this
                list in place; ``ClassNode.getContains`` relies on that when
                it goes on to walk the superclass.

        Returns:
            A new list. ``self.methods`` itself is never modified.

        Raises:
            Exception: on a cyclic hierarchy, on two methods with the same
                signature but different return types, or when a protected
                method replaces a public one.
        """
        # 6. The hierarchy must be acyclic
        if self.name in hierarchy:
            raise Exception("ERROR: The hierarchy is not acyclic '{}', saw '{}'".format(hierarchy, self.name))
        hierarchy.append(self.name)

        # Copy so that inherited methods are never appended to self.methods.
        # Aliasing it would make later calls treat inherited methods as
        # declared ones (e.g. a final inherited method "replacing" itself).
        contains = list(self.methods)

        for inter in self.superInter:
            # Each branch gets its own copy of the path. Otherwise two
            # super-types that share an ancestor (a legal diamond) would be
            # misreported as a cycle.
            superContains = inter.getContains(list(hierarchy))
            for con in superContains:
                conOverwritten = False
                # Compare against the growing result, so methods inherited
                # from earlier super-interfaces are checked too (rule 9).
                # `==` presumably compares signatures in MethodNode; confirm.
                for method in contains:
                    if method == con:
                        # cannot have same signature but different return types
                        if method.methodType != con.methodType:
                            raise Exception("ERROR: Class '{}' contains 2 methods '{}' with the same signature but different return types".format(self.name, method.name))
                        # protected must not replace public
                        if 'protected' in method.mods and 'public' in con.mods:
                            raise Exception("ERROR: Protected method '{}' in class '{}' replaces public method '{}' in class {}".format(method.name, self.name, con.name, inter.name))
                        conOverwritten = True
                        break
                if not conOverwritten:
                    contains.append(con)

        return contains

# class
class ClassNode(ClassInterNode):
    """AST node for a class declaration.

    Built from a parse tree; collects the class modifiers, name, superclass,
    implemented interfaces, fields, methods and constructors. ``superClass``
    and ``superInter`` hold names until ``linkType`` replaces them with the
    linked type nodes.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree):
        self.parseTree = parseTree
        self.name = ''
        self.fields = []
        self.methods = []
        self.constructors = []
        self.mods = []
        self.superClass = '' # these fields initially stores a string that represent the super
        self.superInter = [] #    class/Interface's name, then stores a pointer to the class after type linking
        self.env = None
        self.children = []

        for node in parseTree.children:
            if node.name == 'classMod':
                for m in node.children:
                    self.mods.append(m.lex)
            elif node.name == 'ID':
                self.name = node.lex

            elif node.name == 'superClass':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:  # use for loop to handle no superclass case
                    self.superClass = n.lex
            elif node.name == 'superInterface':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:
                    self.superInter.append(n.lex)

            elif node.name == 'classBody':
                fieldNodes = getParseTreeNodes(['fieldDcl'], node, ['constructorDcl', 'methodDcl'])
                for f in fieldNodes:
                    self.fields.append(FieldNode(f))

                constructorDcl = getParseTreeNodes(['constructorDcl'], node, ['fieldDcl', 'methodDcl'])
                for c in constructorDcl:
                    self.constructors.append(MethodNode(c))

                methodNodes = getParseTreeNodes(['methodDcl'], node, ['constructorDcl', 'fieldDcl'])
                for m in methodNodes:
                    self.methods.append(MethodNode(m))

        self.children += self.fields + self.methods + self.constructors

    def buildEnv(self, parentEnv):
        """Create this class's environment (methods and fields) under parentEnv."""
        env = Env(parentEnv)
        for m in self.methods:
            env.addtoEnv(m)
        for f in self.fields:
            env.addtoEnv(f)
        # not adding constructor to the environment, since it's in the type namespace
        # when looking for a constructor, look for a class with the same name, and look in its constructors field
        self.env = env
        return env

    def linkType(self):
        """Replace superclass/interface names with nodes from the environment.

        Does nothing until ``buildEnv`` has been called.
        """
        # link types to the actual nodes from the environment (envs already created)
        if self.env is not None:
            if self.superClass:
                newSuperClass = self.env.getNode(self.superClass, 'type')
                self.superClass = newSuperClass
            if self.superInter:
                for (index, inter) in enumerate(self.superInter):
                    newSuperInter = self.env.getNode(inter, 'type')
                    self.superInter[index] = newSuperInter

    def checkHierarchy(self):
        """Run the class-specific hierarchy checks, then the shared ones.

        Must be called after ``linkType``.

        Raises:
            Exception: if any of the numbered hierarchy rules is violated.
        """
        # 1. A class must not extend an interface.
        # 4. A class must not extend a final class.
        if self.superClass:
            if not isinstance(self.superClass, ClassNode):
                raise Exception("ERROR: Class '{}' extends non-class '{}'".format(self.name, self.superClass.name))
            if 'final' in self.superClass.mods:
                raise Exception("ERROR: Class '{}' extends final class '{}'".format(self.name, self.superClass.name))

        # 2. A class must not implement a class
        # (3. repeated interfaces are rejected in ClassInterNode.checkHierarchy)
        for inter in self.superInter:
            if not isinstance(inter, InterNode):
                raise Exception("ERROR: Class '{}' implements non-interface '{}'".format(self.name, inter.name))

        # 8. A class must not declare two constructors with the same parameter types
        seenCons = []
        for cons in self.constructors:
            key = (cons.name, cons.paramTypes)
            if key in seenCons:
                raise Exception("ERROR: Class '{}' declares 2 constructors with the same parameter types".format(self.name))
            seenCons.append(key)

        # centralized point for overlapping class & interface logic
        contains = super().checkHierarchy()

        # 10. A class that contains (declares or inherits) any abstract methods must be abstract.
        # A method counts as abstract if it is marked so or has no body
        # (e.g. interface methods).
        if 'abstract' not in self.mods:
            for con in contains:
                if 'abstract' in con.mods or not con.body:
                    raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method".format(self.name))

    # hierarchy: string[]
    def getContains(self, hierarchy):
        """Return declared + interface-inherited + superclass-inherited methods.

        Args:
            hierarchy: Names on the current inheritance path (cycle detection).

        Returns:
            A new list. ``self.methods`` is never modified.

        Raises:
            Exception: if a method conflicts with an inherited one (return
                type, final, static or access rules) or the hierarchy is cyclic.
        """
        # Copy so that extending the list below never writes into the
        # declared-method list, even if the base returns self.methods itself.
        contains = list(super().getContains(hierarchy))

        # get contains from extends class
        if self.superClass:
            # super().getContains appended self.name to `hierarchy`, so the
            # superclass walk sees this class on its path. The copy keeps the
            # caller's list from growing further.
            addToContains = []
            superContains = self.superClass.getContains(list(hierarchy))
            # NOTE(review): `contains` also holds abstract methods inherited
            # from this class's interfaces. A superclass method with the same
            # signature is therefore treated as "replaced" by the interface
            # declaration: a final superclass method implementing an
            # interface method is reported as replaced, and the inherited
            # implementation is not counted for rule 10. Confirm against the
            # Joos/JLS replace rules.
            for con in superContains:
                conOverwritten = False
                for method in contains:
                    if method == con:
                        # cannot have same signature but different return types
                        if method.methodType != con.methodType:
                            raise Exception("ERROR: Class '{}' contains 2 methods '{}' with the same signature but different return types".format(self.name, method.name))
                        # must not replace final
                        if 'final' in con.mods:
                            raise Exception("ERROR: Method '{}' in class '{}' replaces final method '{}' in class '{}'".format(method.name, self.name, con.name, self.superClass.name))
                        # nonstatic must not replace static
                        if 'static' not in method.mods and 'static' in con.mods:
                            raise Exception("ERROR: Non-static method '{}' in class '{}' replaces static method '{}' in class '{}'".format(method.name, self.name, con.name, self.superClass.name))
                        # protected must not replace public
                        if 'protected' in method.mods and 'public' in con.mods:
                            raise Exception("ERROR: Protected method '{}' in class '{}' replaces public method '{}' in class {}".format(method.name, self.name, con.name, self.superClass.name))
                        conOverwritten = True
                        break
                if not conOverwritten:
                    addToContains.append(con)
            contains += addToContains
        return contains

    def getConstructor(self, argTypes):
        """Return the constructor whose paramTypes equal argTypes, or None."""
        for c in self.constructors:
            if c.paramTypes == argTypes:
                return c
        return None

#####################################################################
# interface
class InterNode(ClassInterNode):
    """AST node for an interface declaration.

    Collects the interface name, its super-interfaces (``extendInterface``)
    and its method declarations. ``superInter`` holds names until
    ``linkType`` replaces them with the linked type nodes.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree):
        self.parseTree = parseTree
        self.name = ''
        self.methods = []
        self.superInter = [] # list of strings of extendInterface's name, then stores a pointer to the node after type linking
        self.env = None
        self.children = []

        for node in parseTree.children:
            if node.name == 'ID':
                self.name = node.lex

            elif node.name == 'extendInterface':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:
                    self.superInter.append(n.lex)

            elif node.name == 'interfaceBody':
                nodes = getParseTreeNodes(['interfaceMethodDcl'], node)
                for n in nodes:
                    self.methods.append(MethodNode(n))

        # A separate list (as ClassNode builds), so any later change to
        # self.methods does not silently change self.children, or vice versa.
        self.children = list(self.methods)

    def buildEnv(self, parentEnv):
        """Create this interface's environment (its methods) under parentEnv."""
        env = Env(parentEnv)
        for c in self.children:
            env.addtoEnv(c)
        self.env = env
        return env

    def linkType(self):
        """Replace super-interface names with nodes from the environment.

        Does nothing until ``buildEnv`` has been called.
        """
        # link types to the actual nodes from the environment (envs already created)
        if self.env is not None:
            if self.superInter:
                for (index, inter) in enumerate(self.superInter):
                    newSuperInter = self.env.getNode(inter, 'type')
                    self.superInter[index] = newSuperInter

    def checkHierarchy(self):
        """Run the interface-specific hierarchy checks, then the shared ones.

        Must be called after ``linkType``.

        Raises:
            Exception: if a super-type is not an interface, is repeated, or
                the shared checks fail.
        """
        # 5. An interface must not extend a class.
        # 3. An interface must not be repeated in an extends clause of an interface
        if self.superInter:
            unique = []
            for inter in self.superInter:
                if not isinstance(inter, InterNode):
                    raise Exception("ERROR: Interface '{}' extends non-interface '{}'".format(self.name, inter.name))
                if inter.name in unique:
                    raise Exception("ERROR: Interface '{}' extends duplicate interfaces '{}'".format(self.name, inter.name))
                unique.append(inter.name)

        # centralized point for overlapping class & interface logic; the
        # contained-method list it returns is not needed for interfaces.
        super().checkHierarchy()

    # hierarchy: string[]
    def getContains(self, hierarchy):
        """Return declared and inherited methods; an interface adds nothing extra."""
        # centralized logic
        return super().getContains(hierarchy)