from AST import ASTNode, getParseTreeNodes, getASTNode
from LineNodes import BlockNode, VarDclNode
from ExprPrimaryNodes import makeNodeFromExpr
from UnitNodes import ParamNode
from TheTypeNode import TypeNode
from Environment import Env
from collections import OrderedDict

# field
class FieldNode(ASTNode):
    """AST node for a field declaration inside a class body.

    Attributes (all initialised in __init__):
        parseTree:   parse-tree node this field was built from.
        name:        the field's identifier, copied from variableDcl.name.
        variableDcl: VarDclNode holding the declared type and optional initializer.
        mods:        modifier lexemes collected from 'methodMod' children.
        env:         environment; initialised to None.
        children:    [variableDcl].
        typeName:    name of the enclosing type.
        order:       ordinal compared against other fields' order to detect
                     forward references in initializers.
        myType:      the declared type, copied from variableDcl.myType.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        self.parseTree = parseTree
        self.name = ''
        self.variableDcl = None
        self.mods = []
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order

        for node in parseTree.children:
            # field modifiers arrive under the 'methodMod' production
            if node.name == 'methodMod':
                for m in node.children:
                    self.mods.append(m.lex)

            elif node.name == 'variableDcl':
                self.variableDcl = VarDclNode(node, self.typeName)

        self.name = self.variableDcl.name
        self.myType = self.variableDcl.myType

        self.children.append(self.variableDcl)

    def __eq__(self, other):
        """Fields compare equal when their names match."""
        return self.name == other.name

    def checkType(self):
        """Type-check the declaration, then reject illegal forward references.

        Raises:
            Exception: if the initializer refers to the field being declared,
                or refers (without `this`) to a field of the same enclosing
                type whose order is not smaller than this field's order.
        """
        self.variableDcl.checkType()

        # Only an initializer can contain forward references
        if not self.variableDcl.variableInit:
            return

        for n in getForwardRefNames(self.variableDcl.variableInit):
            # the initializer names the very variable being declared
            if n.prefixLink is self.variableDcl:
                raise Exception("ERROR: Forward reference of field {}  in itself is not allowed.".format(n.prefixLink.name))

            # NOTE(review): `"this" not in n.name` is a containment test; if
            # NameNode.name is a plain string, a simple name such as
            # `thisCount` is exempted too — confirm the type of NameNode.name.
            if n.prefixLink.__class__.__name__ == 'FieldNode' \
            and n.prefixLink.typeName == self.typeName \
            and self.order <= n.prefixLink.order \
            and "this" not in n.name:
                raise Exception("ERROR: Forward reference of field {} is not allowed.".format(n.prefixLink.name))

###########################################################

# method
class MethodNode(ASTNode):
    """AST node for a method, constructor, or interface-method declaration.

    Everything is extracted from `parseTree` in __init__. A constructor has no
    'type'/'VOID' child, so its methodType stays '' (falsy); checkType relies
    on that to tell constructors apart from ordinary methods.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName, order):
        self.parseTree = parseTree
        self.name = ''
        self.methodType = '' # TypeNode of the return type; stays '' for constructors
        self.params = [] # a list of paramNodes
        self.mods = []
        self.body = None
        self.paramTypes = '' # a string of param types (signature) for easy type checking against arguments
        self.env = None
        self.children = []
        self.typeName = typeName
        self.order = order

        # get method name (IDs inside params/type/methodBody are excluded; the last match wins)
        nameNodes = getParseTreeNodes(['ID'], parseTree, ['params', 'type', 'methodBody'])
        for n in nameNodes:
            self.name = n.lex

        # params: one ParamNode each, and the type names concatenated into paramTypes
        nameNodes = getParseTreeNodes(['param'], parseTree, ['methodBody'])
        for n in nameNodes:
            paramNode = ParamNode(n, self.typeName)
            self.params.append(paramNode)
            self.paramTypes += paramNode.paramType.myType.name

        # return type: either the VOID token or a 'type' subtree
        nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
        for n in nameNodes:
            if n.name == 'VOID':
                self.methodType = TypeNode('VOID', self.typeName)
            else:
                self.methodType = TypeNode(n, self.typeName)

        # modifiers and body
        for node in parseTree.children:
            if node.name == 'methodMod' or node.name == "interfaceMod":
                for m in node.children:
                    self.mods.append(m.lex)

            elif node.name == 'methodBody':
                nameNodes = getParseTreeNodes(['block'], node)
                for n in nameNodes:
                    self.body = BlockNode(n, typeName)

        if self.body: self.children.append(self.body)
        # NOTE(review): for constructors this appends '' (methodType's placeholder)
        self.children.append(self.methodType)
        self.children.extend(self.params)

    def __eq__(self, other):
        """Signature equality: same name and pairwise-equal parameter types."""
        if not (self.name == other.name and len(self.params) == len(other.params)):
            return False
        for mine, theirs in zip(self.params, other.params):
            if not mine.paramType == theirs.paramType:
                return False
        return True

    def buildEnv(self, parentEnv):
        """Create a child Env of parentEnv holding the parameters; store it in self.env and return it."""
        env = Env(parentEnv)
        for p in self.params:
            env.addtoEnv(p)
        self.env = env
        return env

    def disambigName(self):
        """Delegate name disambiguation to the body (if any) and to each parameter."""
        if self.body:
            self.body.disambigName()
        for p in self.params:
            p.disambigName()

    def checkType(self):
        """Type-check the signature and body, then validate static-context use and returns.

        Raises:
            Exception: if a static method uses `this` or a non-static member via
                implicit `this`; if a non-void, non-constructor method has no
                return statement; or if a return statement does not match the
                declared return type (any value returned from a void method or
                constructor counts as a mismatch).
        """
        if self.methodType: # constructors keep methodType == ''
            self.myType = self.methodType.myType
        for p in self.params:
            p.checkType()

        # No method body: the method isn't implemented, so nothing more to check
        if not self.body:
            return
        self.body.checkType()

        if 'static' in self.mods:
            self._checkStaticContext()

        self._checkReturns()

    def _checkStaticContext(self):
        """Reject `this` and implicit-this use of non-static members in a static method body."""
        for n in getNameNodes(self.body):
            # NOTE(review): if NameNode.name is a plain string, `'this' in n.name` is a
            # substring test, so a local such as `thisCount` is rejected too — confirm.
            if 'this' in n.name or (n.pointToThis
                                    and n.prefixLink.__class__.__name__ in ('MethodNode', 'FieldNode')
                                    and 'static' not in n.prefixLink.mods):
                raise Exception("ERROR: Cannot use non-static member {} in static method {} in class {}".format(n.name, self.name, self.typeName))

    def _checkReturns(self):
        """Check every return statement in the body against the declared return type."""
        returnNodes = getASTNode(["ReturnNode"], self.body)

        # Constructors (no methodType) and void methods must not return a value
        returnsNothing = not self.methodType or self.myType.name == "void"

        # No return statements at all: fine only when nothing is returned
        if not returnNodes:
            if returnsNothing:
                return
            raise Exception("ERROR: no return statement at function {}".format(self.name))

        # Every return statement is checked, not just the first one
        for n in returnNodes:
            if returnsNothing:
                # only a bare `return;` (which has no type) is allowed
                if n.myType:
                    raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))
            elif not self.myType.assignable(n.myType):
                raise Exception("ERROR: return type of function {} doesn't match with return statement.".format(self.name))

############# helper for forward ref checking ########
# Input: AST node (root of an expression subtree)
# Output: a list of NameNodes to be checked for forward references
def getForwardRefNames(node):
    """Collect the NameNodes in `node`'s subtree that are subject to forward-reference checks.

    - A NameNode is returned as-is; its children are not searched.
    - For an AssignNode only the right-hand side is searched, so a name used
      purely as an assignment target is not collected.
    - None (the node itself or a child) contributes nothing, matching
      getNameNodes.
    """
    if node is None:
        return []

    if node.__class__.__name__ == 'NameNode':
        return [node]

    # Only the assigned value is a "use" of names; the target is skipped
    if node.__class__.__name__ == 'AssignNode':
        return getForwardRefNames(node.right)

    result = []
    for c in node.children:
        result.extend(getForwardRefNames(c))
    return result

# Input: AST node (may be None)
# Output: a list of the outermost NameNodes in the subtree
def getNameNodes(node):
    """Return the outermost NameNodes found under `node`, in pre-order.

    Traversal runs left to right over each node's `children`. Falsy entries
    (e.g. None) are skipped, and a NameNode's own children are never visited.
    """
    found = []
    pending = [node]
    while pending:
        current = pending.pop()
        if not current:
            continue
        if current.__class__.__name__ == 'NameNode':
            found.append(current)
            continue
        # push children in reverse so the leftmost child is popped first
        pending.extend(reversed(list(current.children)))
    return found