from AST import ASTNode, getParseTreeNodes, getASTNode
from LineNodes import BlockNode, VarDclNode
from ExprPrimaryNodes import makeNodeFromExpr
from UnitNodes import ParamNode
from TheTypeNode import TypeNode
from Environment import Env
from collections import OrderedDict
from CodeGenUtils import p, pLabel, genProcedure, importHelper

# field
class FieldNode(ASTNode):
    """AST node for a field declared in a class body.

    Wraps the field's VarDclNode and records the field's modifiers, name,
    declared type, the name of the owning type and its declaration order
    (the order is used by the forward-reference check in checkType).
    """

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        self.parseTree = parseTree
        self.name = ''
        self.variableDcl = None
        self.mods = []
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order  # position of this field among the type's declarations

        for node in parseTree.children:
            if node.name == 'methodMod':
                for m in node.children:
                    self.mods.append(m.lex)

            elif node.name == 'variableDcl':
                self.variableDcl = VarDclNode(node, self.typeName, isField=self)

        # NOTE(review): assumes the parse tree always has a 'variableDcl' child;
        # otherwise the next line fails with AttributeError on None.
        self.name = self.variableDcl.name
        self.myType = self.variableDcl.myType

        self.children.append(self.variableDcl)

    def __eq__(self, other):
        # Fields are considered equal when their names match.
        return self.name == other.name

    def checkType(self):
        """Type-check the declaration and reject illegal forward references.

        Raises:
            Exception: if the initializer names the field being declared, or
                names a field of the same type whose declaration order is at or
                after this field's, unless the name contains "this".
        """
        self.variableDcl.checkType()

        # check forward reference (only names in the initializer matter)
        if self.variableDcl.variableInit:
            allNames = getForwardRefNames(self.variableDcl.variableInit)

            for n in allNames:
                if n.prefixLink is self.variableDcl:
                    raise Exception("ERROR: Forward reference of field {}  in itself is not allowed.".format(n.prefixLink.name))

                if (n.prefixLink.__class__.__name__ == 'FieldNode'
                        and n.prefixLink.typeName == self.typeName
                        and self.order <= n.prefixLink.order
                        and "this" not in n.name):
                    raise Exception("ERROR: Forward reference of field {} is not allowed.".format(n.prefixLink.name))

    # Note: 1. Not calling codeGen for variableDcl since all variableDcl code are assuming that the variable is a LOCAL variable
    #       2. Only calling codeGen on the variableInit of self.variableDcl if the field is static, since non-static fields should
    #          be initialized by the constructor
    def codeGen(self):
        """Generate data/code for a static field; non-static fields produce nothing here.

        Sets self.data to a labelled `dd 0` slot for the field and, when the
        field has an initializer, sets self.code to code that evaluates it and
        stores eax into that slot. Runs only once per node.
        """
        if hasattr(self, "code"):
            return
        self.code = ""
        self.data = ""
        label = self.typeName + "_" + self.name

        # static fields: the storage slot lives in the assembly data section
        if "static" in self.mods:
            self.data += ";Declaring a static field: " + label + "\n"
            self.data += pLabel(name=label, type="static") + \
                         p(instruction="dd", arg1=0, comment="Declaring space on assembly for a static field")
            self.data += ";End of declaration of static field\n"

            # Initializing static fields
            # static fields are initialized in the order of declaration within the class and have to be initialized
            # before the test() method is called
            initNode = self.variableDcl.variableInit
            if initNode:
                self.code += ";Start of initialization of static field\n"
                initNode.codeGen()

                self.code += "; Calculating the initial value of declared field: " + label + "\n"
                self.code += initNode.code
                # Store the computed value (left in eax) into the field's slot
                self.code += p(instruction="mov", arg1="ebx", arg2="dword S_"+label) + \
                             p(instruction="mov", arg1="[ebx]", arg2="eax", comment="eax is a pointer to field value in heap")

                self.code += ";End of initialization of static field\n"

###########################################################

# method
class MethodNode(ASTNode):
    """AST node for a method, interface method or constructor declaration.

    Records the method's name, modifiers, parameters, declared return type and
    optional body. Also implements environment building, name
    disambiguation, type checking, reachability checking and assembly code
    generation for the method.
    """

    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        self.parseTree = parseTree
        self.name = ''
        self.methodType = '' # TypeNode of the return type; stays '' when no 'type'/'VOID' node is found
        self.params = [] # a list of paramNodes
        self.mods = []
        self.body = None
        self.paramTypes = '' # a string of param types (signature) for easy type checking against arguments
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order
        self.myType = None
        self.isInterM = False # if self is an interface method
        self.replace = None # pointer to the method that this method overrides

        # get method name (if several IDs match, the last one wins)
        nameNodes = getParseTreeNodes(['ID'], parseTree, ['params', 'type', 'methodBody'])
        for n in nameNodes:
            self.name = n.lex

        # params; array-typed params contribute myType.name + "Array" to the signature string
        nameNodes = getParseTreeNodes(['param'], parseTree, ['methodBody'])
        for n in nameNodes:
            paramNode = ParamNode(n, self.typeName)
            self.params.append(paramNode)
            if paramNode.paramType.myType.isArray:
                self.paramTypes += paramNode.paramType.myType.name + "Array"
            else:
                self.paramTypes += paramNode.paramType.myType.name

        # return type
        nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
        for n in nameNodes:
            if n.name == 'VOID':
                self.methodType = TypeNode('VOID', self.typeName)
            else:
                self.methodType = TypeNode(n, self.typeName)

        for node in parseTree.children:
            if node.name == 'methodMod' or node.name == "interfaceMod":
                for m in node.children:
                    self.mods.append(m.lex)

            elif node.name == 'methodBody':
                nameNodes = getParseTreeNodes(['block'], node)
                for n in nameNodes:
                    self.body = BlockNode(n, typeName)

        if self.body:
            self.children.append(self.body)
        self.children.append(self.methodType)
        self.children.extend(self.params)

    def __eq__(self, other):
        """Methods are equal when they have the same name and pairwise-equal parameter types."""
        if self.name != other.name or len(self.params) != len(other.params):
            return False
        return all(mine.paramType == theirs.paramType
                   for mine, theirs in zip(self.params, other.params))

    def buildEnv(self, parentEnv):
        """Create this method's environment (child of parentEnv) containing its params."""
        env = Env(parentEnv)
        for param in self.params:
            env.addtoEnv(param)
        self.env = env
        return env

    def disambigName(self):
        if self.body:
            self.body.disambigName()
        for param in self.params:
            param.disambigName()

    def checkType(self):
        """Type-check params and body, then validate static-context use and return statements.

        Also sets ``method`` on every ReturnNode in the body to this node.

        Raises:
            Exception: if a static method uses ``this`` or a non-static
                method/field through an implicit this, or if a return
                statement does not match the declared return type.
        """
        if self.methodType: # constructor would be None
            self.myType = self.methodType.myType
        for param in self.params:
            param.checkType()
        if self.body:
            self.body.checkType()

        # Checking return types against the function type
        # No method body: do not check type as function isn't implemented
        if not self.body:
            return

        # check no use of this / non-static members in static method
        # NOTE(review): n.pointToThis is assumed to mark names resolved through an implicit `this` — confirm in NameNode
        if 'static' in self.mods:
            names = getNameNodes(self.body)
            for n in names:
                if 'this' in n.name or (
                        n.pointToThis
                        and n.prefixLink.__class__.__name__ in ('MethodNode', 'FieldNode')
                        and 'static' not in n.prefixLink.mods):
                    raise Exception("ERROR: Cannot use non-static member {} in static method {} in class {}".format(n.name, self.name, self.typeName))

        # With method body: every return statement is linked to this method and checked
        returnNodes = getASTNode(["ReturnNode"], self.body)

        for n in returnNodes:
            n.method = self
            # Checking for functions of type void
            # Only valid if the return statement carries no value (return;)
            if self.myType and self.myType.name == "void":
                if n.myType:
                    raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))
                continue
            # Checking for non void cases
            if self.myType and not self.myType.assignable(n.myType):
                raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))
        return

    def reachCheck(self, inMaybe=True):
        """Run reachability analysis on the body; non-void methods must not complete normally.

        Raises:
            Exception: if a non-void, non-constructor method's body has out = maybe.
        """
        self.outMaybe = False

        # Check reachability of method body
        if self.body:
            self.body.reachCheck(True) # For method bodies, in[L] = maybe by default
            self.outMaybe = self.body.outMaybe

        # Check if out[method body] is a maybe for non-void methods
        if not self.methodType or self.myType.name == "void":  # Omitting the check for constructors and void functions
            return
        if self.outMaybe:
            raise Exception("Non-void method '{}' in class '{}' does not return".format(self.name, self.typeName))
        return

    def _assignParamOffsets(self, stackoffset):
        """Set each param's ebp offset; the last param sits at `stackoffset`, earlier ones above it."""
        for i, param in enumerate(reversed(self.params)):
            param.offset = i * 4 + stackoffset

    def _reserveLocals(self):
        """Assign stack offsets to the body's local declarations.

        Returns:
            (code, count): the code that pushes a zero slot for each local, and the number of locals.
        """
        localVars = getVarDclNodes(self.body)
        code = ""
        for i, localVar in enumerate(localVars):
            # NOTE(review): assumes genProcedure's prologue places the first local at ebp-16 — confirm
            localVar.offset = i * 4 + 16
            code += p("push", 0)
        return code, len(localVars)

    def codeGen(self):
        """Generate self.code for a (non-constructor) method; runs once per node.

        A method without a body only gets its label plus `mov eax, 0` / `ret`.
        """
        if hasattr(self, "code") and self.code != "":
            return

        self.label = "M_" + self.typeName + "_" + self.name + "_" + self.paramTypes
        self.code = pLabel(self.typeName + "_" + self.name + "_" + self.paramTypes, "method")

        # params: stack is ebp, eip, [object], params; static methods receive no object
        self._assignParamOffsets(8 if 'static' in self.mods else 12)

        if self.body:
            # push all local var to stack
            bodyCode, numLocals = self._reserveLocals()

            self.body.codeGen()
            bodyCode += self.body.code

            bodyCode += self.label + "_end:            ; end of method for " + self.name + "\n"

            # pop off all the local var
            for _ in range(numLocals):
                bodyCode += p("pop", "edx")

            self.code += genProcedure(bodyCode, "method definition for " + self.name)
        else:
            self.code += p("mov", "eax", "0")
            self.code += p("ret", "")

    # This method is called instead of codeGen if this is a constructor
    def codeGenConstructor(self):
        """Generate self.code for a constructor; runs once per node.

        The emitted procedure does the following, in order:
        - calls the superclass's zero-argument constructor (if any), passing `this`;
        - stores the initializer of every non-static field that has one, in
          declaration order;
        - runs the constructor body.
        """
        if hasattr(self, "code") and self.code != "":
            return

        # populate param offsets: stack is ebp, eip, this, params
        self._assignParamOffsets(12)

        myClass = self.env.getNode(self.typeName, 'type')

        self.label = "M_" + self.typeName  + "_" + self.name + "_" + self.paramTypes
        self.code = pLabel(self.typeName  + "_" + self.name + "_" + self.paramTypes, "method") # label
        thisLoc = 8 # Right after ebp, eip
        bodyCode = ""
        numLocals = 0

        if self.body:
            # push all local var to stack
            bodyCode, numLocals = self._reserveLocals()

        # call parent constructor(zero argument)
        if myClass.superClass:
            suLabel = "M_" + myClass.superClass.name + "_" + myClass.superClass.name + "_"
            bodyCode += importHelper(myClass.superClass.name, self.typeName, suLabel)
            bodyCode += p("mov", "eax", "[ebp + " + str(thisLoc) + "]")
            bodyCode += p("push", "eax", None, "# Pass THIS as argument to superClass.")
            bodyCode += p("call", suLabel)
            bodyCode += p("add", "esp", "4") # pop object off stack

        # init fields
        fields = sorted(myClass.fields, key=lambda x: x.order)
        for f in fields:
            if not 'static' in f.mods and f.variableDcl.variableInit:
                f.variableDcl.variableInit.codeGen()
                bodyCode += f.variableDcl.variableInit.code
                bodyCode += p("mov", "ebx", "[ebp + " + str(thisLoc) + "]") # THIS
                bodyCode += p("mov", "[ebx + " + str(f.offset) + " ]", "eax")

        # body code
        if self.body:
            self.body.codeGen()
            bodyCode += self.body.code
            bodyCode += self.label + "_end:            ; end of method for " + self.name + "\n"

            # pop off all the local var
            for _ in range(numLocals):
                bodyCode += p("pop", "edx")
        else:
            bodyCode += p("mov", "eax", "0")

        self.code += genProcedure(bodyCode, "Constructor definition for " + self.name + " " + self.paramTypes)

    # gets the top-most level method that this method overrides
    def getTopReplace(self):
        if not self.replace:
            return self
        else:
            return self.replace.getTopReplace()

############# helper for forward ref checking ########
# Input: AST Node
# Output: A list of NameNodes to be checked for forward references (for assignments, only the right-hand side is searched)
def getForwardRefNames(node):
    """Return the NameNodes under *node* that a forward-reference check must inspect.

    A NameNode is returned as-is. For an AssignNode only the right-hand side is
    searched, since the assignment target is not a read. Every other node is
    searched through its children, in order.
    """
    kind = node.__class__.__name__
    if kind == 'NameNode':
        return [node]

    subtrees = [node.right] if kind == 'AssignNode' else node.children
    names = []
    for sub in subtrees:
        names.extend(getForwardRefNames(sub))
    return names

# Input: AST Node (may be None)
# Output: A list of all NameNodes in the subtree, to be checked
def getNameNodes(node):
    """Collect every NameNode in the subtree rooted at *node*, in pre-order.

    A falsy *node* (e.g. None) yields an empty list. Traversal does not
    descend below a NameNode.
    """
    if not node:
        return []
    if node.__class__.__name__ == 'NameNode':
        return [node]
    return [name for child in node.children for name in getNameNodes(child)]

# Input: Block Node
# Output: A list of local var dcl to be pushed onto the stack
def getVarDclNodes(node):
    """Return the local VarDclNodes declared in *node* that need a stack slot.

    A VarDclNode is returned as-is. A BlockNode is scanned statement by
    statement, descending into:
    - nested blocks;
    - for-loop inits and bodies;
    - while bodies;
    - if/else branches.
    Any other node yields an empty list.
    """
    kind = node.__class__.__name__
    if kind == 'VarDclNode':
        return [node]
    if kind != 'BlockNode':
        return []

    result = []
    for s in node.statements:
        stmtKind = s.__class__.__name__

        if stmtKind in ('VarDclNode', 'BlockNode'):
            # Recurse on the statement itself (recursing on `node` looped forever).
            result += getVarDclNodes(s)

        elif stmtKind == 'ForNode':
            # Only a declaration in the for-init needs a slot; an expression init
            # (e.g. an assignment) must not be given a stack offset.
            if s.forInit:
                result += getVarDclNodes(s.forInit)
            if s.bodyStatement:
                result += getVarDclNodes(s.bodyStatement)

        elif stmtKind == 'WhileNode':
            if s.whileBody:
                result += getVarDclNodes(s.whileBody)

        elif stmtKind == 'IfNode':
            result += getVarDclNodes(s.ifBody)
            if s.elseBody:
                result += getVarDclNodes(s.elseBody)

    return result