from AST import ASTNode, getParseTreeNodes
from MemberNodes import FieldNode, MethodNode
from Environment import Env
from TheTypeNode import TypeStruct

# types: class, interface

class ClassInterNode(ASTNode):
    """Common base for class (ClassNode) and interface (InterNode) declarations.

    Holds the state shared by both kinds of type declaration and implements the
    hierarchy checks and inherited-member computation that apply to both.
    """

    def __init__(self, parseTree, packageName):
        self.parseTree = parseTree
        self.packageName = packageName
        self.name = ''
        self.methods = []
        # class/interface names as strings; replaced in place by the linked nodes during type linking
        self.superInter = []
        self.env = None
        self.children = []
        self.canonName = ''

        # sets
        self.inherits = []  # members inherited (not redeclared) from supertypes; set by getContains
        self.super = []  # direct supertypes; populated by the subclasses' linkType

    def checkHierarchy(self):
        """Run the hierarchy checks shared by classes and interfaces.

        Once the checks pass, every inherited method/field is added to this
        node's environment.

        Raises:
            Exception: if any of the checked hierarchy rules is violated.
        """
        # 3. An interface must not be repeated in an implements clause
        seenInters = set()
        for inter in self.superInter:
            if inter.canonName in seenInters:
                raise Exception("ERROR: Class/Interface '{}' implements duplicate interfaces '{}'".format(self.name, inter.canonName))
            seenInters.add(inter.canonName)

        # 7. A class or interface must not declare two methods with the same signature (name and parameter types).
        # 9. A class/interface must not contain two methods with the same signature but different return types
        # (kept as a list in case paramTypes is unhashable)
        seenSigs = []
        for method in self.methods:
            key = (method.name, method.paramTypes)
            if key in seenSigs:
                raise Exception("ERROR: Class/Interface '{}' declares 2 methods with the same signature '{}'".format(self.name, key[0]))

            # quick fix for errors with java.lang.Object getClass function:
            # a parameterless getClass declared anywhere but java.lang.Object replaces Object's final method
            if self.canonName != 'java.lang.Object' and key[0] == 'getClass' and not key[1]:
                raise Exception("ERROR: Method 'getClass' in class/interface '{}' replaces java.lang.Object's final method".format(self.name))

            seenSigs.append(key)

        contains = self.getContains([])

        # 10. A class that contains any abstract methods must be abstract.
        if isinstance(self, ClassNode) and 'abstract' not in self.mods:
            for con in contains:
                if not isinstance(con, MethodNode):
                    continue
                if 'abstract' in con.mods:
                    raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method".format(self.name))
                # a bodiless, non-native method (e.g. an unimplemented interface method) is implicitly abstract
                if not con.body and 'native' not in con.mods and con not in self.constructors:
                    raise Exception("ERROR: Non-abstract Class '{}' contains an abstract method {}".format(self.name, con.name))

        # add inherited methods/fields to environment
        for i in self.inherits:
            self.env.addtoEnv(i)

    # hierarchy: string[]
    def getContains(self, hierarchy):
        """Return every member (method/field) this type contains: declared plus inherited.

        hierarchy is the list of canonical names on the current walk path; it is
        used to detect cycles. The first call that finds inherited members also
        records them (those not redeclared here) in self.inherits.

        Raises:
            Exception: if the hierarchy is cyclic, or a member replacement
                violates one of the rules checked by safeReplace.
        """
        # ---- ACYCLIC CHECK ---- #
        # 6. The hierarchy must be acyclic
        if self.canonName in hierarchy:
            raise Exception("ERROR: The hierarchy is not acyclic '{}', saw '{}'".format(hierarchy, self.canonName))
        path = hierarchy + [self.canonName]

        # ---- GET SUPER CONTAINS ---- #
        superContains = []

        # parent interface methods
        superInterContains = []
        for inter in self.superInter:
            superInterContains.extend(inter.getContains(path))

        # parent class: every class except java.lang.Object has one. When no
        # extends clause is given it is java.lang.Object -- this matters for
        # classes that only implement interfaces, so that Object's concrete
        # methods (e.g. toString) can implement abstract interface methods and
        # Object's non-public methods are still inherited.
        superClass = getattr(self, 'superClass', None)
        if isinstance(self, ClassNode) and not superClass and self.canonName != 'java.lang.Object':
            superClass = self.env.getNode('java.lang.Object', 'type')

        if superClass:
            superClassContains = superClass.getContains(path)

            # replace superInter methods that superClass implements
            # this adds methods from superInterContains to superContains
            # example: Tests/A2/J1_supermethod_override
            for sic in superInterContains:
                sicOverwritten = False
                for scc in superClassContains:
                    if type(sic) == type(scc) and sic == scc:
                        # sic is implicitly abstract, if scc is abstract, then it is not replacing sic
                        # Thus scc will be added to superContains as well (having 1+ abstract methods with same signature is fine)
                        # Example: Tests/A3/J1_supermethod_override11/
                        if 'abstract' not in scc.mods:
                            safeReplace(sic, scc, self.name)
                            sicOverwritten = True
                            break
                if not sicOverwritten:
                    superContains.append(sic)

            superContains.extend(superClassContains)

        elif not self.superInter and self.canonName != 'java.lang.Object':
            # only interfaces reach here: an interface without any super interfaces
            # implicitly declares an abstract version of every public method in java.lang.Object
            objectNode = self.env.getNode('java.lang.Object', 'type')
            for oc in objectNode.getContains(path):
                if 'public' in oc.mods:
                    superContains.append(oc)

        else:
            # interface with super interfaces (or java.lang.Object, which has none)
            superContains.extend(superInterContains)

        # ---- SUPER CONTAINS AGAINST SELF DECLARES ---- #
        inherits = []
        contains = []
        contains.extend(self.methods)
        if hasattr(self, 'fields'):
            contains.extend(self.fields)

        for sc in superContains:
            scOverwritten = False
            for c in contains:
                if type(sc) == type(c) and sc == c:
                    safeReplace(sc, c, self.name)
                    scOverwritten = True
                    break
            if not scOverwritten:
                contains.append(sc)
                inherits.append(sc)

        if not self.inherits:
            self.inherits.extend(inherits)

        return contains

    def printSets(self):
        """Debug helper: print this type's super set and its inherited members."""
        print("---- Sets for Class/Interface {}".format(self.name))

        print("> self.super")
        for c in self.super:
            print(c.name)

        print("> self.inherits")
        for i in self.inherits:
            if isinstance(i, MethodNode):
                # format() rather than '+' so this also works when paramTypes is not a str
                print("{}({})".format(i.name, i.paramTypes))
            if isinstance(i, FieldNode):
                print(i.name)

    def checkType(self):
        """Type-check every child node that supports it."""
        for c in self.children:
            if c and hasattr(c, 'checkType'):
                c.checkType()

# class
class ClassNode(ClassInterNode):
    """AST node for a class declaration, built from its parse tree."""

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, packageName):
        super().__init__(parseTree, packageName)
        self.fields = []
        self.constructors = []
        self.mods = []
        # initially the superclass name as a string ('' if none); replaced by the linked node in linkType
        self.superClass = ''

        for node in parseTree.children:
            if node.name == 'classMod':
                for m in node.children:
                    self.mods.append(m.lex)
            elif node.name == 'ID':
                self.name = node.lex

            elif node.name == 'superClass':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:  # use for loop to handle no superclass case
                    self.superClass = n.lex
            elif node.name == 'superInterface':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:
                    self.superInter.append(n.lex)

            elif node.name == 'classBody':
                # order records each member's position in the class body
                order = 0
                memberDcls = getParseTreeNodes(['classBodyDcl'], node, ['constructorDcl', 'methodDcl', 'fieldDcl'])

                for m in memberDcls:
                    if m.children[0].name == 'fieldDcl':
                        self.fields.append(FieldNode(m.children[0], self.name, order))
                    elif m.children[0].name == 'methodDcl':
                        self.methods.append(MethodNode(m.children[0], self.name, order))
                    elif m.children[0].name == 'constructorDcl':
                        self.constructors.append(MethodNode(m.children[0], self.name, order))
                    order += 1

        self.canonName = self.packageName + '.' + self.name
        self.myType = TypeStruct(self.canonName, self)
        self.children += self.fields + self.methods + self.constructors

    def buildEnv(self, parentEnv):
        """Create this class's environment (child of parentEnv) holding its methods and fields."""
        env = Env(parentEnv)
        for m in self.methods:
            env.addtoEnv(m)
        for f in self.fields:
            env.addtoEnv(f)
        # not adding constructor to the environment, since it's in the type namespace
        # when looking for a constructor, look for a class with the same name, and look in its constructors field
        self.env = env
        return env

    def linkType(self):
        """Resolve superclass/superinterface names to nodes and build self.super."""
        # link types to the actual nodes from the environment (envs already created)
        if self.env is not None:
            if self.superClass:
                newSuperClass = self.env.getNode(self.superClass, 'type')
                self.superClass = newSuperClass
                self.super.append(newSuperClass)
            if self.superInter:
                for (index, inter) in enumerate(self.superInter):
                    newSuperInter = self.env.getNode(inter, 'type')
                    self.superInter[index] = newSuperInter
                    self.super.append(newSuperInter)
            # NOTE(review): java.lang.Object is only added when there are no other
            # supertypes at all; a class with interfaces but no extends clause does
            # not get it here -- confirm consumers of self.super expect that.
            if not self.super and self.canonName != 'java.lang.Object':
                objectNode = self.env.getNode('java.lang.Object', 'type')
                self.super.append(objectNode)

    def checkHierarchy(self):
        """Run the class-specific hierarchy checks, then the shared ones.

        Raises:
            Exception: if any hierarchy rule is violated.
        """
        # 1. A class must not extend an interface.
        # 4. A class must not extend a final class.
        if self.superClass:
            if not isinstance(self.superClass, ClassNode):
                raise Exception("ERROR: Class '{}' extends non-class '{}'".format(self.name, self.superClass.name))
            if 'final' in self.superClass.mods:
                raise Exception("ERROR: Class '{}' extends final class '{}'".format(self.name, self.superClass.name))

        # 2. A class must not implement a class
        # (3. duplicate interfaces in the implements clause are rejected by ClassInterNode.checkHierarchy)
        for inter in self.superInter:
            if not isinstance(inter, InterNode):
                raise Exception("ERROR: Class '{}' implements non-interface '{}'".format(self.name, inter.name))

        # 8. A class must not declare two constructors with the same parameter types
        seenCons = []
        for cons in self.constructors:
            key = (cons.name, cons.paramTypes)
            if key in seenCons:
                raise Exception("ERROR: Class '{}' declares 2 constructors with the same parameter types".format(self.name))
            seenCons.append(key)

        # overlapping class/interface logic
        super().checkHierarchy()

    def checkType(self):
        """Type-check members, then require every constructor to be named after the class."""
        super().checkType()

        # Checking if constructor's name is the same as class name
        for constructor in self.constructors:
            if not self.name == constructor.name:
                raise Exception("ERROR: Constructor {0} doesn't have the same name as class {1}".format(constructor.name, self.name))


#####################################################################
# interface
class InterNode(ClassInterNode):
    """AST node for an interface declaration, built from its parse tree."""

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, packageName):
        super().__init__(parseTree, packageName)

        for node in parseTree.children:
            if node.name == 'ID':
                self.name = node.lex

            elif node.name == 'extendInterface':
                nameNodes = getParseTreeNodes(['ID', 'COMPID'], node)
                for n in nameNodes:
                    self.superInter.append(n.lex)

            elif node.name == 'interfaceBody':
                nodes = getParseTreeNodes(['interfaceMethodDcl'], node)
                for n in nodes:
                    self.methods.append(MethodNode(n, self.name, 0))  # order = 0 since no method body in interface needs to be checked

        self.canonName = self.packageName + '.' + self.name
        self.myType = TypeStruct(self.canonName, self)
        self.children.extend(self.methods)

    def buildEnv(self, parentEnv):
        """Create this interface's environment (child of parentEnv) holding its methods."""
        env = Env(parentEnv)
        for c in self.children:
            env.addtoEnv(c)
        self.env = env
        return env

    def linkType(self):
        """Replace super-interface names with the nodes they resolve to in self.env."""
        # link types to the actual nodes from the environment (envs already created)
        if self.env is not None:
            for (index, inter) in enumerate(self.superInter):
                newSuperInter = self.env.getNode(inter, 'type')
                self.superInter[index] = newSuperInter
                self.super.append(newSuperInter)

    def checkHierarchy(self):
        """Run the interface-specific hierarchy checks, then the shared ones.

        Raises:
            Exception: if any hierarchy rule is violated.
        """
        # 5. An interface must not extend a class.
        # 3. An interface must not be repeated in an extends clause of an interface
        # Duplicates are compared by canonical name: two distinct interfaces in
        # different packages may share the same simple name.
        seen = set()
        for inter in self.superInter:
            if not isinstance(inter, InterNode):
                raise Exception("ERROR: Interface '{}' extends non-interface '{}'".format(self.name, inter.name))
            if inter.canonName in seen:
                raise Exception("ERROR: Interface '{}' extends duplicate interfaces '{}'".format(self.name, inter.canonName))
            seen.add(inter.canonName)

        # centralized point for overlapping class & interface logic.
        super().checkHierarchy()

# helper - replace method check
# cur/new: MethodNode or FieldNode
def safeReplace(cur, new, className):
    """Check that member `new` may replace member `cur` (same signature, same kind).

    Args:
        cur: the MethodNode/FieldNode being replaced.
        new: the MethodNode/FieldNode replacing it.
        className: name of the class/interface where the replacement happens,
            used only in error messages.

    Raises:
        Exception: if the replacement violates rules 9/11/12/13/14.
    """
    # getting here signifies that cur and new have the same signature (and are the same type)

    methodOrField = 'method'
    if isinstance(cur, FieldNode):
        methodOrField = 'field'

    # 11. A nonstatic method must not replace a static method
    if 'static' in cur.mods and 'static' not in new.mods:
        # {3} is the replacing type, {4} the replaced one
        raise Exception("ERROR: In class {0}, non-static {1} '{2}' in class '{3}' replaces static {1} in class/interface {4}".format(className, methodOrField, new.name, new.typeName, cur.typeName))

    # 9. A class/interface must not contain two methods with the same signature but different return types
    # 12. A method must not replace a method with a different return type
    if isinstance(cur, MethodNode) and cur.methodType != new.methodType:
        raise Exception("ERROR: In class {}, method '{}' in class '{}' replaces method with a different return type in class/interface {}".format(className, new.name, new.typeName, cur.typeName))

    # 13. A protected method must not replace a public method
    if 'public' in cur.mods and 'protected' in new.mods:
        raise Exception("ERROR: In class {0}, protected {1} '{2}' from class '{3}' replaces public {1} from class/interface {4}".format(className, methodOrField, new.name, new.typeName, cur.typeName))

    # 14. A method must not replace a final method
    # quick fix for final method getClass from java.lang.Object, which is exempted here
    if 'final' in cur.mods and cur.name != 'getClass':
        raise Exception("ERROR: In class {0}, {1} '{2}' in class '{3}' replaces final {1} in class/interface {4}".format(className, methodOrField, new.name, new.typeName, cur.typeName))