Skip to content
Snippets Groups Projects
TypeNodes.py 19.21 KiB
from AST import ASTNode, getParseTreeNodes
from MemberNodes import FieldNode, MethodNode
from Environment import Env
from TheTypeNode import TypeStruct

# types: class, interface

class ClassInterNode(ASTNode):
    """Common base of ClassNode and InterNode.

    Stores the state shared by class and interface declarations and
    implements the hierarchy checks and the member-set ("contains")
    computation that both kinds of declaration rely on.
    """

    def __init__(self, parseTree, packageName):
        self.parseTree = parseTree
        self.packageName = packageName
        self.name = ''
        self.methods = []
        # Names of the direct superinterfaces; linkType() later replaces each
        # entry with the linked class/interface node.
        self.superInter = []
        self.env = None
        self.children = []
        self.canonName = ''

        # sets
        self.inherits = []  # inherited members, cached by the first getContains() call that finds any
        self.super = []     # direct supertypes, appended to by the subclasses' linkType()

    def checkHierarchy(self):
        """Run the hierarchy checks shared by classes and interfaces.

        After the checks pass, the members recorded in ``self.inherits`` are
        added to ``self.env``.

        Raises:
            Exception: when one of the numbered well-formedness rules below is
                violated (getContains() may raise as well).
        """
        # 3. An interface must not be repeated in an implements clause
        seenInterfaces = set()
        for inter in self.superInter:
            if inter.canonName in seenInterfaces:
                raise Exception("ERROR: Class/Interface '{}' implements duplicate interfaces '{}'".format(self.name, inter.canonName))
            seenInterfaces.add(inter.canonName)

        # 7. A class or interface must not declare two methods with the same signature (name and parameter types).
        # 9. A class/interface must not contain two methods with the same signature but different return types
        seenSignatures = []
        for method in self.methods:
            name, params = method.name, method.paramTypes
            if (name, params) in seenSignatures:
                raise Exception("ERROR: Class/Interface '{}' declares 2 methods with the same signature '{}'".format(self.name, name))

            # quick fix for errors with java.lang.Object getClass function:
            # only java.lang.Object itself may declare a parameterless getClass()
            if self.canonName != 'java.lang.Object' and name == 'getClass' and not params:
                raise Exception("ERROR: Method 'getClass' in class/interface '{}' replaces java.lang.Object's final method".format(self.name))

            seenSignatures.append((name, params))

        contained = self.getContains([])

        # 10. A class that contains any abstract methods must be abstract.
        if isinstance(self, ClassNode):
            classIsAbstract = 'abstract' in self.mods
            for member in contained:
                if not isinstance(member, MethodNode):
                    continue
                if 'abstract' in member.mods and not classIsAbstract:
                    raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method".format(self.name))
                # a bodiless, non-native method also counts as abstract here
                # (constructors are excluded from this check)
                if not (member.body or 'native' in member.mods or classIsAbstract or member in self.constructors):
                    raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method {}".format(self.name, member.name))

        # add inherited methods/fields to environment
        for inherited in self.inherits:
            self.env.addtoEnv(inherited)

    # hierarchy: string[]
    def getContains(self, hierarchy):
        """Return the members this type contains: declared plus inherited.

        Args:
            hierarchy: canonical names of the types on the current
                getContains() recursion path, used for cycle detection.

        Returns:
            list: declared methods, then declared fields (classes only), then
            the inherited members that no earlier entry replaced.

        Side effects:
            Fills ``self.inherits`` with the inherited members if it is
            still empty.

        Raises:
            Exception: if the hierarchy is cyclic, or if safeReplace() rejects
                a replacement.
        """
        # ---- ACYCLIC CHECK ---- #
        # 6. The hierarchy must be acyclic
        if self.canonName in hierarchy:
            raise Exception("ERROR: The hierarchy is not acyclic '{}', saw '{}'".format(hierarchy, self.canonName))
        path = hierarchy + [self.canonName]

        # ---- GET SUPER CONTAINS ---- #
        # parent interface members (gathered before the superclass's members)
        interMembers = []
        for inter in self.superInter:
            interMembers.extend(inter.getContains(path))

        superMembers = []
        superClass = getattr(self, 'superClass', None)
        if superClass:
            classMembers = superClass.getContains(path)

            # A superinterface member that the superclass provides with a
            # non-abstract member is replaced by it; every other one is kept.
            # An abstract superclass member does not replace it, so both end up
            # in the set (having 1+ abstract methods with the same signature is fine).
            # Examples: Tests/A2/J1_supermethod_override, Tests/A3/J1_supermethod_override11/
            for interMember in interMembers:
                for classMember in classMembers:
                    if type(interMember) == type(classMember) and interMember == classMember and 'abstract' not in classMember.mods:
                        safeReplace(interMember, classMember, self.name)
                        break
                else:
                    superMembers.append(interMember)

            superMembers.extend(classMembers)

        elif not self.superInter and self.canonName != 'java.lang.Object':
            # An interface without superinterfaces implicitly declares an abstract
            # version of every public method of java.lang.Object; a class with no
            # supertypes takes all of Object's members.
            objectNode = self.env.getNode('java.lang.Object', 'type')
            for objectMember in objectNode.getContains(path):
                if isinstance(self, ClassNode) or 'public' in objectMember.mods:
                    superMembers.append(objectMember)

        else:
            # no superclass: the superinterface members are inherited as-is
            superMembers.extend(interMembers)

        # ---- SUPER CONTAINS AGAINST SELF DECLARES ---- #
        contained = list(self.methods)
        if hasattr(self, 'fields'):
            contained.extend(self.fields)

        inherited = []
        for candidate in superMembers:
            # each candidate is compared against the declared members AND the
            # inherited members accepted so far (the list grows below)
            for member in contained:
                if type(candidate) == type(member) and candidate == member:
                    safeReplace(candidate, member, self.name)
                    break
            else:
                contained.append(candidate)
                inherited.append(candidate)

        if not self.inherits:
            self.inherits.extend(inherited)

        return contained

    def printSets(self):
        """Print this type's direct supertypes and inherited members (debugging aid)."""
        print("---- Sets for Class/Interface {}".format(self.name))

        print("> self.super")
        for supertype in self.super:
            print(supertype.name)

        print("> self.inherits")
        for member in self.inherits:
            if isinstance(member, MethodNode):
                print(member.name + "(" + member.paramTypes + ")")
            if isinstance(member, FieldNode):
                print(member.name)

    def checkType(self):
        """Type-check every non-empty child node that defines checkType()."""
        for child in self.children:
            if not child or not hasattr(child, 'checkType'):
                continue
            child.checkType()

# class
class ClassNode(ClassInterNode):
    """AST node for a class declaration.

    Besides the shared ClassInterNode state, it records the class's
    modifiers, fields, constructors and superclass, and the labels/offsets
    used when emitting the class memory layout in code().
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, packageName):
        super().__init__(parseTree, packageName)
        self.fields = []
        self.constructors = []
        self.mods = []
        self.superClass = '' # initially the superclass name ('' when there is no extends clause); linkType() replaces it with the linked node
        #### Code generation relevant fields ####
        self.label = "" # label in assembly
        self.methodOffset = {} # a dictionary that maps method signatures (method.name, method.paramTypes) to offsets in the memory layout
        self.staticFieldLabels = [] # a list of static field labels 

        for node in parseTree.children:
            if node.name == 'classMod':
                for m in node.children:
                    self.mods.append(m.lex)
            elif node.name == 'ID':
                self.name = node.lex

            elif node.name == 'superClass':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:  # use for loop to handle no superclass case
                    self.superClass = n.lex
            elif node.name == 'superInterface':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:
                    self.superInter.append(n.lex)

            elif node.name == 'classBody':
                # `order` is each member's position in the class body
                order = 0
                memberDcls = getParseTreeNodes(['classBodyDcl'], node, ['constructorDcl', 'methodDcl', 'fieldDcl'])

                for m in memberDcls:
                    if m.children[0].name == 'fieldDcl':
                        self.fields.append(FieldNode(m.children[0], self.name, order))
                    elif m.children[0].name == 'methodDcl':
                        self.methods.append(MethodNode(m.children[0], self.name, order))
                    elif m.children[0].name == 'constructorDcl':
                        self.constructors.append(MethodNode(m.children[0], self.name, order))
                    order += 1

        self.canonName = self.packageName + '.' + self.name
        self.myType = TypeStruct(self.canonName, self)
        self.children += self.fields + self.methods + self.constructors

    def buildEnv(self, parentEnv):
        """Create this class's environment (methods and fields) under *parentEnv* and return it."""
        env = Env(parentEnv)
        for m in self.methods:
            env.addtoEnv(m)
        for f in self.fields:
            env.addtoEnv(f)
        # not adding constructor to the environment, since it's in the type namespace
        # when looking for a constructor, look for a class with the same name, and look in its constructors field
        self.env = env
        return env

    def linkType(self):
        """Replace supertype names with the linked nodes and build ``self.super``.

        A class with no explicit supertypes gets java.lang.Object as its only
        entry in ``self.super`` (except java.lang.Object itself).
        """
        # link types to the actual nodes from the environment (envs already created)
        # also create super set
        if self.env is not None:
            if self.superClass:
                newSuperClass = self.env.getNode(self.superClass, 'type')
                self.superClass = newSuperClass
                self.super.append(newSuperClass)
            if self.superInter:
                for (index, inter) in enumerate(self.superInter):
                    newSuperInter = self.env.getNode(inter, 'type')
                    self.superInter[index] = newSuperInter
                    self.super.append(newSuperInter)
            if not self.super and self.canonName != 'java.lang.Object':
                objectNode = self.env.getNode('java.lang.Object', 'type')
                self.super.append(objectNode)

    def checkHierarchy(self):
        """Run the class-specific hierarchy checks, then the shared ones.

        Raises:
            Exception: when a numbered well-formedness rule below (or one in
                ClassInterNode.checkHierarchy) is violated.
        """
        # 1. A class must not extend an interface.
        # 4. A class must not extend a final class.
        if self.superClass:
            if not isinstance(self.superClass, ClassNode):
                raise Exception("ERROR: Class '{}' extends non-class '{}'".format(self.name, self.superClass.name))
            if 'final' in self.superClass.mods:
                raise Exception("ERROR: Class '{}' extends final class '{}'".format(self.name, self.superClass.name))

        # 2. A class must not implement a class
        # (3. duplicate interfaces are rejected in ClassInterNode.checkHierarchy)
        for inter in self.superInter:
            if not isinstance(inter, InterNode):
                raise Exception("ERROR: Class '{}' implements non-interface '{}'".format(self.name, inter.name))

        # 8. A class must not declare two constructors with the same parameter types
        unique = []
        for cons in self.constructors:
            key = (cons.name, cons.paramTypes)
            if key in unique:
                raise Exception("ERROR: Class '{}' declares 2 constructors with the same parameter types".format(self.name))
            unique.append(key)

        # overlapping class/interface logic
        super().checkHierarchy()

    def checkType(self):
        """Type-check the members, then require every constructor to be named after the class."""
        super().checkType()

        # Checking if constructor's name is the same as class name
        for constructor in self.constructors:
            if not self.name == constructor.name:
                raise Exception("ERROR: Constructor {0} doesn't have the same name as class {1}".format(constructor.name, self.name))
        return

    def code(self):
        """Emit this class's memory layout (method table), then its members' code.

        Each method slot is declared as a ``dd`` under the label
        ``VM_<class>_<name>_<paramTypes>``. Each slot is then filled with the
        address of the implementing ``M_<class>_<name>_<paramTypes>`` label.
        The superclass's ``methodOffset`` is copied, so this relies on the
        superclass's code() having run first.

        Returns:
            str: the generated assembly.
        """
        genCode = ""

        # Generate class label
        self.label = "C_" + self.name
        genCode += self.oneLineAssemblyHelper(label=self.label, comment="START OF CLASS MEMORY LAYOUT FOR CLASS " + self.canonName)

        # TODO: SIT and subtype testing tables

        ####### ADDING METHODS TO CLASS MEMORY LAYOUT #########
        # 1. Copying over the offsets of methods from superclass and DECLARING memory segment for the methods
        lastOffset = -4  # stores the largest offset in the superclass
                        # TODO: set this to 4 after the implementation of both the SIT and subtype testing table
                        # Note: This is 4 less than the offset of where the next method would be located
                        #       This is to accommodate for the addition of 4 to lastOffset in EVERY (including the first) iteration in the
                        #       loops that loop through self.constructors and self.methods, in the case where there is no superClass
        if self.superClass:
            for key, value in self.superClass.methodOffset.items():
                self.methodOffset[key] = value
                newLabel = "VM_" + self.name + "_" + key[0] + "_" + key[1]
                genCode += self.oneLineAssemblyHelper(label=newLabel, operator="dd", op1="64") # just declaring a memory segment with a placeholder value
                lastOffset = max(value, lastOffset)

        # 2. Assigning offsets to constructors and DECLARING memory segment for them
        for method in self.constructors:
            lastOffset += 4
            self.methodOffset[(method.name, method.paramTypes)] = lastOffset
            newLabel = "VM_" + self.name + "_" + method.name + "_" + method.paramTypes
            genCode += self.oneLineAssemblyHelper(label=newLabel, operator="dd", op1="64") # just declaring a memory segment with a placeholder value

        # 3. Assigning offsets to methods that aren't in the super class and DECLARING memory segment for them
        for method in self.methods:
            if not (method.name, method.paramTypes) in self.methodOffset:
                lastOffset += 4
                self.methodOffset[(method.name, method.paramTypes)] = lastOffset
                newLabel = "VM_" + self.name + "_" + method.name + "_" + method.paramTypes
                genCode += self.oneLineAssemblyHelper(label=newLabel, operator="dd", op1="64") # just declaring a memory segment with a placeholder value
        genCode += self.oneLineAssemblyHelper(comment="END OF CLASS MEMORY LAYOUT FOR CLASS " + self.name)

        # 4. Fill in the memory segments declared above with the addresses of the method implementations.
        #    Slots inherited without being overridden point at the superclass's
        #    implementation, since no M_<this class> label exists for them.
        for key in self.methodOffset:
            vmLabel = "VM_" + self.name + "_" + key[0] + "_" + key[1] # method at class's vtable
            mLabel = self._implementationLabel(key)
            genCode += self.oneLineAssemblyHelper(operator="mov", op1="eax", op2=vmLabel, comment="Filling in class memory segment for method " + mLabel)
            # an explicit size is required when storing an immediate through a memory operand
            genCode += self.oneLineAssemblyHelper(operator="mov", op1="dword [eax]", op2=mLabel)

        ###########################################################

        for c in self.children:
            if c and hasattr(c, 'code'):
                genCode += c.code()

        return genCode

    def _implementationLabel(self, key):
        """Return the ``M_...`` label of the nearest class that declares method/constructor *key*.

        Walks from this class up the superclass chain. If no class in the
        chain declares *key*, it falls back to this class's label (the
        previous behavior).
        """
        cls = self
        while isinstance(cls, ClassNode):
            for member in cls.methods + cls.constructors:
                if (member.name, member.paramTypes) == key:
                    return "M_" + cls.name + "_" + key[0] + "_" + key[1]
            cls = cls.superClass
        return "M_" + self.name + "_" + key[0] + "_" + key[1]







        



            



#####################################################################
# interface
class InterNode(ClassInterNode):
    """AST node for an interface declaration."""
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, packageName):
        super().__init__(parseTree, packageName)

        for node in parseTree.children:
            if node.name == 'ID':
                self.name = node.lex

            elif node.name == 'extendInterface':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:
                    self.superInter.append(n.lex)

            elif node.name == 'interfaceBody':
                nodes = getParseTreeNodes(['interfaceMethodDcl'], node)
                for n in nodes:
                    self.methods.append(MethodNode(n, self.name, 0))  # order = 0 since no method body in interface needs to be checked

        self.canonName = self.packageName + '.' + self.name
        self.myType = TypeStruct(self.canonName, self)
        self.children.extend(self.methods)

    def buildEnv(self, parentEnv):
        """Create this interface's environment (its methods) under *parentEnv* and return it."""
        env = Env(parentEnv)
        for c in self.children:
            env.addtoEnv(c)
        self.env = env
        return env

    def linkType(self):
        """Replace superinterface names with the linked nodes and record them in ``self.super``."""
        # link types to the actual nodes from the environment (envs already created)
        if self.env is not None:
            if self.superInter:
                for (index, inter) in enumerate(self.superInter):
                    newSuperInter = self.env.getNode(inter, 'type')
                    self.superInter[index] = newSuperInter
                    self.super.append(newSuperInter)

    def checkHierarchy(self):
        """Run the interface-specific hierarchy checks, then the shared ones.

        Raises:
            Exception: if a superinterface is not an interface, is listed twice,
                or a shared rule in ClassInterNode.checkHierarchy is violated.
        """
        # 5. An interface must not extend a class.
        # 3. An interface must not be repeated in an extends clause of an interface.
        #    Duplicates are detected by canonical name: two different interfaces
        #    may share a simple name if they live in different packages.
        seen = []
        for inter in self.superInter:
            if not isinstance(inter, InterNode):
                raise Exception("ERROR: Interface '{}' extends non-interface '{}'".format(self.name, inter.name))
            if inter.canonName in seen:
                raise Exception("ERROR: Interface '{}' extends duplicate interfaces '{}'".format(self.name, inter.name))
            seen.append(inter.canonName)

        # centralized point for overlapping class & interface logic.
        super().checkHierarchy()

# helper - replace method check
# cur/new: MethodNode or FieldNode
def safeReplace(cur, new, className):
    """Check that member *new* may legally replace member *cur*.

    Callers invoke this only after establishing that *cur* and *new* have the
    same signature and the same node type (both MethodNode or both FieldNode).

    Args:
        cur: the member being replaced.
        new: the member replacing it.
        className: name of the class/interface whose member set is being
            built; used only in error messages.

    Raises:
        Exception: if the replacement violates one of the numbered rules below.
    """
    methodOrField = 'method'
    if isinstance(cur, FieldNode):
        methodOrField = 'field'

    # 11. A nonstatic method must not replace a static method
    if 'static' in cur.mods and 'static' not in new.mods:
        raise Exception("ERROR: In class {0}, non-static {1} '{2}' in class '{3}' replaces static {1} in class/interface {4}".format(className, methodOrField, new.name, new.typeName, cur.typeName))

    # 9. A class/interface must not contain two methods with the same signature but different return types
    # 12. A method must not replace a method with a different return type
    if isinstance(cur, MethodNode) and cur.methodType != new.methodType:
        raise Exception("ERROR: In class {}, method '{}' in class '{}' replaces method with a different return type in class/interface {}".format(className, new.name, new.typeName, cur.typeName))

    # 13. A protected method must not replace a public method
    if 'public' in cur.mods and 'protected' in new.mods:
        raise Exception("ERROR: In class {0}, protected {1} '{2}' from class '{3}' replaces public {1} from class/interface {4}".format(className, methodOrField, new.name, new.typeName, cur.typeName))

    # 14. A method must not replace a final method
    # quick fix: getClass is exempt, so replacing java.lang.Object's final getClass is not reported here
    if 'final' in cur.mods and cur.name != 'getClass':
        raise Exception("ERROR: In class {0}, {1} '{2}' in class '{3}' replaces final {1} in class/interface {4}".format(className, methodOrField, new.name, new.typeName, cur.typeName))