from AST import ASTNode, getParseTreeNodes
from Environment import Env
from UnitNodes import LiteralNode
import MemberNodes
from TheTypeNode import TypeNode, TypeStruct
from NameNode import NameNode, checkProtected

# file containing the smaller (lower-level) nodes of the AST
# nodes in this file:
#   ArgsNode
#   ArrayAccessNode
#   ArrayCreateNode
#   AssignNode
#   CastNode
#   ClassCreateNode
#   ExprNode
#   FieldAccessNode
#   MethodInvNode

# TODO: go over the nodes in this file to see if they need to override buildEnv

###########################################################
# factory methods
###########################################################

# parses an expr node of the parse tree into the matching AST node (a primary node, NameNode, AssignNode, TypeNode, CastNode or ExprNode)
def makeNodeFromExpr(parseTree, typeName):
    """Build the AST node for an expression subtree of the parse tree.

    Single-child wrapper nodes are skipped until a node that can be turned
    into an AST node directly is reached.

    Args:
        parseTree: parse tree node to convert (typically an 'expr').
        typeName: name of the type (class/interface) the node belongs under.

    Returns:
        A primary node, NameNode, AssignNode, TypeNode, CastNode or ExprNode.
    """
    node = parseTree
    while True:
        kind = node.name
        if kind == 'primaryAndArray':
            return makeNodeFromAllPrimary(node, typeName)
        if kind in ('ID', 'COMPID'):
            return NameNode(node, False, typeName) # TODO is this always False??
        if kind == 'assignment':
            return AssignNode(node, typeName)
        if kind == 'refType':
            return TypeNode(node, typeName)
        if len(node.children) != 1:
            break
        # pure pass-through production: descend into its only child
        node = node.children[0]

    # a node with several children is either a cast or an operator expression
    if node.name == 'castExpr':
        return CastNode(node, typeName)
    return ExprNode(node, typeName)

# parses the primaryAndArray/primary/primaryNoArrayAccess node in the parse tree and return corresponding AST nodes
def makeNodeFromAllPrimary(parseTree, typeName):
    """Build the AST node for a primaryAndArray / primary / primaryNoArrayAccess subtree.

    Args:
        parseTree: one of the three primary-like parse tree nodes.
        typeName: name of the type (class/interface) the node belongs under.

    Returns:
        ArrayCreateNode, ArrayAccessNode, LiteralNode, ClassCreateNode,
        MethodInvNode, FieldAccessNode, or the node for a parenthesised expr.

    Raises:
        Exception: if the primaryNoArrayAccess child is not a recognised form.
    """
    tree = parseTree

    # primaryAndArray -> arrayCreationExpr | primary
    if tree.name == 'primaryAndArray':
        head = tree.children[0]
        if head.name == 'arrayCreationExpr':
            return ArrayCreateNode(head, typeName)
        if head.name == 'primary':
            inner = head.children[0]
            if inner.name == 'arrayAccess':
                return ArrayAccessNode(inner, typeName)
            tree = inner

    # primary -> arrayAccess | primaryNoArrayAccess
    if tree.name == 'primary':
        inner = tree.children[0]
        if inner.name == 'arrayAccess':
            return ArrayAccessNode(inner, typeName)
        tree = inner

    # primaryNoArrayAccess: dispatch on its first child
    first = tree.children[0]
    kind = first.name
    if kind == 'literal':
        return LiteralNode(first, typeName)
    if kind == 'LPAREN':  # primaryNoArrayAccess LPAREN expr RPAREN
        return makeNodeFromExpr(tree.children[1], typeName)
    if kind == 'classInstanceCreate':
        return ClassCreateNode(first.children[0], typeName)
    if kind == 'methodInvoc':
        return MethodInvNode(first, typeName)
    if kind == 'fieldAccess':
        return FieldAccessNode(first, typeName)
    raise Exception('ERROR: something wrong at primaryNoArrayAccess')

###########################################################
# helper methods
###########################################################
def helperDisambigName(node):
    """Call ``disambigName`` on *node* if it is a NameNode; otherwise do nothing.

    ``None`` and falsy nodes are ignored. Any exception raised by
    ``disambigName`` propagates unchanged to the caller.
    """
    # Compared by class name (not isinstance), matching the original check.
    if node and node.__class__.__name__ == "NameNode":
        node.disambigName()

###########################################################

########################### Node Definitions #####################################

###################################################################################
# args exprs, exprs expr COMMA exprs
class ArgsNode(ASTNode):
    """Argument list of a call: one AST node per 'expr' found in the parse tree."""
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.exprs = []  # a list of expressions
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        argTrees = getParseTreeNodes(['expr'], parseTree)
        self.exprs = [makeNodeFromExpr(argTree, typeName) for argTree in argTrees]
        # children is a separate list holding the same argument nodes
        self.children.extend(self.exprs)

    def disambigName(self):
        """Disambiguate names in every argument expression, in order."""
        for argNode in self.exprs:
            argNode.disambigName()


###################################################################################
# Array Access
class ArrayAccessNode(ASTNode):
    """Array element access ``array[index]``.

    Accepted parse tree shapes:
      arrayAccess name LSQRBRACK expr RSQRBRACK
      arrayAccess ID LSQRBRACK expr RSQRBRACK
      arrayAccess primaryNoArrayAccess LSQRBRACK expr RSQRBRACK
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.array = ''  # either a variableName, a field, or array access
        self.index = '' # expr
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        if (parseTree.children[0].name == 'primaryNoArrayAccess'):
            self.array = makeNodeFromAllPrimary(parseTree.children[0], typeName)
        else:
            self.array = NameNode(parseTree.children[0], False, typeName)
        self.index = makeNodeFromExpr(parseTree.children[2], typeName)

        self.children.append(self.array)
        self.children.append(self.index)

    def disambigName(self):
        self.array.disambigName()
        self.index.disambigName()

    def checkType(self):
        """Type-check the operands and set myType to the array's element type.

        Raises:
            Exception: if the accessed expression is not an array, or the
                index expression is not numeric.
        """
        # NOTE(review): workaround kept from the original; the author notes
        # disambigName was not always called before checkType.
        self.array.disambigName()
        self.array.checkType()
        self.index.checkType()
        if not self.array.myType.isArray:
            # Only name-like nodes have .name; primaries (e.g. method calls)
            # do not, so fall back to the type name rather than crash.
            arrayDesc = getattr(self.array, 'name', self.array.myType.name)
            raise Exception("ERROR: Cannot perform array access on non-array {}".format(arrayDesc))
        if not self.index.myType.isNum():
            raise Exception("ERROR: Array index must be a number.")
        # Element type: same name/typePointer as the array's type; isArray is
        # not set here (presumably TypeStruct defaults it to False — confirm).
        self.myType = TypeStruct(self.array.myType.name, self.array.myType.typePointer)


###################################################################################
# arrayCreationExpr
# arrayCreationExpr NEW primitiveType LSQRBRACK expr  RSQRBRACK
# arrayCreationExpr NEW name LSQRBRACK expr RSQRBRACK
# arrayCreationExpr NEW primitiveType LSQRBRACK RSQRBRACK
# arrayCreationExpr NEW name LSQRBRACK RSQRBRACK
class ArrayCreateNode(ASTNode):
    """Array creation ``new T[size]`` (the size expression may be absent).

    Accepted parse tree shapes:
      arrayCreationExpr NEW primitiveType LSQRBRACK expr  RSQRBRACK
      arrayCreationExpr NEW name LSQRBRACK expr RSQRBRACK
      arrayCreationExpr NEW primitiveType LSQRBRACK RSQRBRACK
      arrayCreationExpr NEW name LSQRBRACK RSQRBRACK

    When no size expression is present, ``arraySize`` keeps the sentinel int 0.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.arrayType = ''
        self.arraySize = 0 # or Expr
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        # input is arrayCreationExpr NEW type LSQRBRACK expr RSQRBRACK
        self.arrayType = TypeNode(parseTree.children[1], typeName)

        expr = getParseTreeNodes(['expr'], parseTree)
        if len(expr) > 0:
            self.arraySize = makeNodeFromExpr(expr[0], typeName)

        self.children.append(self.arrayType)
        self.children.append(self.arraySize)

    def disambigName(self):
        # The sentinel 0 (no size expression) has nothing to disambiguate.
        if self.arraySize != 0:
            self.arraySize.disambigName()

    def checkType(self):
        """Check the size expression (if any) and set myType to the array type.

        Raises:
            Exception: if a size expression is present and is not numeric.
        """
        # Only an actual size expression has a myType to inspect; the
        # sentinel int 0 used to fall through and crash on ``0.myType``.
        if self.arraySize != 0:
            self.arraySize.checkType()
            if not self.arraySize.myType.isNum():
                raise Exception("ERROR: Array index must be a number.")
        self.myType = TypeStruct(self.arrayType.myType.name, self.arrayType.myType.typePointer)
        self.myType.isArray = True

###################################################################################
# assignment leftHandSide ASSIGN expr
class AssignNode(ASTNode):
    """Assignment ``leftHandSide = expr``.

    The left-hand side is a FieldAccessNode, ArrayAccessNode or NameNode,
    depending on the first child of the leftHandSide parse node.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.left = None
        # right side is built first, before the left-hand side
        self.right = makeNodeFromExpr(parseTree.children[2], typeName)
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        lhsTree = parseTree.children[0].children[0]
        lhsBuilder = {
            'fieldAccess': FieldAccessNode,
            'arrayAccess': ArrayAccessNode,
        }.get(lhsTree.name)
        if lhsBuilder is None:
            self.left = NameNode(lhsTree, False, typeName)
        else:
            self.left = lhsBuilder(lhsTree, typeName)

        # children order: right side first, then left side
        self.children.extend([self.right, self.left])

    def disambigName(self):
        self.left.disambigName()
        self.right.disambigName()

    def checkType(self):
        """Require the right side to be assignable to the left; myType is the left's type."""
        self.left.checkType()
        self.right.checkType()

        leftType = self.left.myType
        rightType = self.right.myType
        if not leftType.assignable(rightType):
            raise Exception("ERROR: assignment operation failed. Cannot assign type {0} to type {1} at class {2}".format(leftType.name, rightType.name, self.typeName))
        self.myType = leftType

    def reachCheck(self, inMaybe):
        """Reachability check: the assignment must be reachable; reachability passes through."""
        if inMaybe:
            self.outMaybe = inMaybe
            return
        raise Exception("ERROR: not reaching a assignment statement")


##################################################################################
# cast: castExpr LPAREN castType RPAREN unaryNotPlusMinus
class CastNode(ASTNode):
    """Cast expression ``(left) right``.

    Parse tree: castExpr LPAREN castType RPAREN unaryNotPlusMinus
    ``left`` is the target type (TypeNode, or an expression node when the
    castType was parsed as an expr); ``right`` is the operand expression.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.left = parseTree.children[1]  # cast: (left)right
        self.right = makeNodeFromExpr(parseTree.children[3], typeName) # expr
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        if self.left.name == 'expr':
            self.left = makeNodeFromExpr(self.left, typeName)
        else: #primitiveType or ArrayType
            self.left = TypeNode(self.left, typeName)
        # since a type might be mis-parsed as a name
        if self.left.__class__.__name__ == 'NameNode':
            self.left = TypeNode(self.parseTree.children[1], typeName)

        self.children.append(self.left)
        self.children.append(self.right)

    def disambigName(self):
        # a TypeNode target has no names to disambiguate
        if self.left.__class__.__name__ != 'TypeNode':
            self.left.disambigName()
        self.right.disambigName()

    def checkType(self):
        """Allow numeric<->numeric casts and casts between mutually assignable types.

        Sets myType to the cast's target type.

        Raises:
            Exception: if the operand type cannot be cast to the target type.
        """
        self.left.checkType()
        # NOTE(review): operand is re-disambiguated here, as in the original
        self.right.disambigName()
        self.right.checkType()
        leftType = self.left.myType
        rightType = self.right.myType
        if (leftType.isNum() and rightType.isNum()) \
        or leftType.assignable(rightType) \
        or rightType.assignable(leftType):
            self.myType = leftType
            return
        raise Exception("ERROR: Cannot cast type {} to type {}.".format(rightType.name, leftType.name))

###################################################################################
# unqualCreate NEW name LPAREN args RPAREN
class ClassCreateNode(ASTNode):
    """Class instance creation ``new Name(args)``.

    Parse tree: unqualCreate NEW name LPAREN args RPAREN
    After checkType, ``cons`` holds the constructor selected for the arguments.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.className = TypeNode(parseTree.children[1], typeName)
        self.args = ArgsNode(parseTree.children[3], typeName)
        self.env = None
        self.children = [self.className, self.args]
        self.typeName = typeName
        self.cons = None # the constructor used to create the class

    def checkType(self):
        """Validate the instantiation and resolve its constructor.

        Raises:
            Exception: if the type is not a class, is abstract, has a
                superclass without a zero-argument constructor, has no
                constructor matching the argument types, or the matching
                constructor is protected and the current class is in a
                different package.
        """
        self.args.checkType()
        classDef = self.className.myType.typePointer

        # Check the kind of type first: interfaces (or an unresolved pointer)
        # would otherwise hit the 'abstract' test and get a misleading
        # "abstract class" message or an AttributeError.
        if classDef.__class__.__name__ != 'ClassNode':
            raise Exception('ERROR: Cannot create an instance of {}, it is not a class.'.format(self.className.myType.name))
        if 'abstract' in classDef.mods:
            raise Exception('ERROR: Cannot create an instance of abstract class {}.'.format(self.className.myType.name))

        # every superclass in the chain must have a zero-argument constructor;
        # '' marks "no explicit superclass" (implicitly java.lang.Object, which is safe)
        su = classDef.superClass
        while su != '':
            if not any(c.params == [] for c in su.constructors):
                raise Exception("ERROR: Class {} doesn't have a zero-arguement constructor.".format(su.name))
            su = su.superClass

        # get constructor using arg Types ("" matches any constructor name)
        m = getMethod(classDef.constructors, "", self.args)
        if not m:
            raise Exception("ERROR: Class {} doesn't have a constructor with given argument types.".format(classDef.name))
        self.cons = m
        self.myType = self.className.myType

        # a protected constructor may only be used from the same package
        if 'protected' in self.cons.mods:
            curClass = self.env.getNode(self.typeName, 'type')
            if curClass.packageName != classDef.packageName:
                raise Exception("ERROR: In class {0}, using a protected constructor, but class {1} is not in class {0}'s package ({2}).".format(curClass.name, classDef.name, curClass.packageName))


#################################################################################
# condOrExpr
class ExprNode(ASTNode):
    """Operator expression.

    Binary form ``left op right``: built from children[0], children[1].lex
    and children[2] of the parse node.
    Unary form (parse node 'unaryNotPlusMinus' or 'unaryExpr'): ``op right``,
    with ``left`` left as None.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.left = None
        self.op = ''
        self.right = None  # another expr
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        if parseTree.name == 'unaryNotPlusMinus' or parseTree.name == 'unaryExpr':
            self.op = parseTree.children[0].lex
            self.right = makeNodeFromExpr(parseTree.children[1], typeName)
        else:
            self.left = makeNodeFromExpr(parseTree.children[0], typeName)
            self.op = parseTree.children[1].lex
            self.right = makeNodeFromExpr(parseTree.children[2], typeName)

        # NOTE: left is appended even when it is None (unary form)
        self.children.append(self.left)
        self.children.append(self.right)

    def disambigName(self):
        if self.left:
            self.left.disambigName()
        self.right.disambigName()

    def checkType(self):
        """Type-check the operands, apply the operator's typing rule, set myType.

        Raises:
            Exception: if the operand types are not valid for the operator.
        """
        super().checkType() # check children's type first to populate their myType field

        # Unary operations:
        if not self.left:
            if self.op == '-' and self.right.myType.isNum():
                self.myType = TypeStruct("int", None)
                return
            if self.op == '!' and self.right.myType.name == 'boolean':
                self.myType = TypeStruct("boolean", None)
                return
            # Without this, a bad unary operand fell through to the binary
            # rules below and crashed dereferencing self.left (None).
            raise Exception("ERROR: Incompatible type. Operand of {} type can't be used with unary operation {}".format(self.right.myType.name, self.op))

        leftType = self.left.myType
        rightType = self.right.myType

        # Numeric types
        if leftType.isNum() and rightType.isNum():
            # Comparisons:
            if self.op in ['==', '!=', '<=', '>=', '>', '<']:
                self.myType = TypeStruct("boolean", None)
                return
            # numeric operations:
            elif self.op in ['+', '-', '*', '/', '%']:
                self.myType = TypeStruct("int", None)
                return
        # Boolean operations:
        if leftType.name == 'boolean' and rightType.name == 'boolean':
            if self.op in ['&&', '&', '|', '||', '!=', '==']:
                self.myType = TypeStruct("boolean", None)
                return

        # Equality / instanceof between related types
        if leftType.assignable(rightType) or rightType.assignable(leftType):
            if self.op == '==' or self.op == '!=' or self.op == 'instanceof':
                self.myType = TypeStruct("boolean", None)
                return

        # String concat: String + anything non-void (either side)
        if ((leftType.name =='java.lang.String' and rightType.name not in ['void']) \
        or (rightType.name =='java.lang.String' and leftType.name not in ['void'])) and self.op == '+':
            self.myType = TypeStruct('java.lang.String', self.env.getNode('java.lang.String', 'type'))
            self.myType.link(self.env)
            return

        raise Exception("ERROR: Incompatible types. Left of {} type can't be used with right of {} type on operation {}".format(leftType.name, rightType.name, self.op))

    @staticmethod
    def _javaDiv(a, b):
        """Integer division rounding toward zero (Java semantics; Python's // floors)."""
        q = abs(a) // abs(b)
        return q if (a < 0) == (b < 0) else -q

    @staticmethod
    def _javaRem(a, b):
        """Remainder whose sign follows the dividend (Java semantics; Python's % follows the divisor)."""
        return a - b * ExprNode._javaDiv(a, b)

    # returns True, False, Int or None (for non-constant expr)
    # children of exprNode is either exprNode or literalNode
    def getConstant(self):
        """Fold this expression to a constant value, or return None if it is not constant."""
        if not hasattr(self.right, "getConstant"):
            return None
        cRight = self.right.getConstant()
        if cRight == None:
            return None

        # Unary Ops
        if not self.left:
            if self.op == '-':
                return -cRight
            return not cRight # op = '!'

        if not hasattr(self.left, "getConstant"):
            return None
        cLeft = self.left.getConstant()
        if cLeft == None:
            return None

        # arithmetic
        if self.op == '+':
            return cLeft + cRight
        elif self.op == '-':
            return cLeft - cRight
        elif self.op == '*':
            return cLeft * cRight
        elif self.op in ('/', '%'):
            # Division by zero throws at run time in Java, so the expression
            # is not a compile-time constant (and must not crash the compiler).
            if cRight == 0:
                return None
            if self.op == '/':
                return ExprNode._javaDiv(cLeft, cRight)
            return ExprNode._javaRem(cLeft, cRight)
        # Comparison
        elif self.op == '==':
            return cLeft == cRight
        elif self.op == '!=':
            return cLeft != cRight
        elif self.op == '>':
            return cLeft > cRight
        elif self.op == '<':
            return cLeft < cRight
        elif self.op == '>=':
            return cLeft >= cRight
        elif self.op == '<=':
            return cLeft <= cRight
        # boolean Ops
        elif self.op == '&&' or self.op == '&':
            return cLeft and cRight
        elif self.op == '||' or self.op == '|':
            return cLeft or cRight
        else:
            return None

###################################################################################
# fieldAccess primary PERIOD ID
class FieldAccessNode(ASTNode):
    """Field access ``primary.ID``.

    Parse tree: fieldAccess primary PERIOD ID
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.primary = ''
        self.ID = '' # method/fieldName
        self.env = None
        self.children = []
        self.typeName = typeName # the type (class/interface) this node belongs under

        # input: fieldAccess primary PERIOD ID
        self.primary = makeNodeFromAllPrimary(parseTree.children[0], typeName)
        self.ID = NameNode(parseTree.children[2], False, typeName)

        self.children.append(self.primary)
        self.children.append(self.ID)

    def disambigName(self):
        if not self.primary: # this implies that the ID has nothing that comes before it
            self.ID.disambigName()
        else:
            # the ID is resolved against the primary's type later, in checkType
            self.primary.disambigName()

    def checkType(self):
        """Resolve the field against the primary's type and set myType to the field's type.

        Raises:
            Exception: propagated from the ID's type check or from
                checkProtected when a protected member is accessed illegally.
        """
        self.primary.checkType()
        # arrays/primitives have no type declaration: link the primary node itself
        if self.primary.myType.isArray or self.primary.myType.isPrimitive:
            self.ID.prefixLink = self.primary
        else:
            self.ID.prefixLink = self.primary.myType.typePointer
        self.ID.checkType()
        self.myType = self.ID.myType

        # check protected. Only the missing-modifiers case is tolerated; the
        # previous bare `except:` also swallowed the error raised by
        # checkProtected, so illegal protected accesses were never reported.
        mods = getattr(self.ID.prefixLink, 'mods', None)
        if mods and "protected" in mods:
            checkProtected(self.ID.prefixLink, self)



###################################################################################
# methodInvoc
class MethodInvNode(ASTNode):
    """Method invocation.

    Accepted parse tree shapes:
      methodInvoc primary PERIOD ID LPAREN args RPAREN
      methodInvoc name LPAREN args RPAREN
    After checkType, ``method`` holds the resolved method declaration.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.primary = None  # can be empty
        self.ID = '' # can be either ID or compID
        self.args = None
        self.env = None
        self.children = []
        self.method = None
        self.typeName = typeName # the type (class/interface) this node belongs under

        # ID and args sit at fixed offsets from the end in both shapes
        self.ID = NameNode(parseTree.children[-4], True, typeName)
        self.args = ArgsNode(parseTree.children[-2], typeName)
        if parseTree.children[0].name == 'primary':
            self.primary = makeNodeFromAllPrimary(parseTree.children[0], typeName)

        self.children.append(self.primary)
        self.children.append(self.args)
        self.children.append(self.ID)

    def disambigName(self):
        if not self.primary: # this implies our ID doesn't have anything that comes before it
            # only qualified names (more than one ID) need disambiguation
            if isinstance(self.ID, NameNode) and len(self.ID.IDs) > 1:
                self.ID.disambigName()
        if self.args:
            self.args.disambigName()

    def checkType(self):
        """Resolve the invoked method from the argument types and set myType to its return type.

        Raises:
            Exception: if no method matches, on static/non-static misuse, or
                (via checkProtected) on illegal protected access.
        """
        # populate params myTypes
        for param in self.args.exprs:
            param.checkType()

        m = None
        if not self.primary:
            self.ID.checkType()
            m = getMethod(self.ID.prefixLink.values(), self.ID.methodName, self.args)
        else:
            self.primary.checkType()
            methods = []
            methods.extend(self.primary.myType.typePointer.methods)
            # need to check inherited methods as well
            methods.extend([meth for meth in self.primary.myType.typePointer.inherits if isinstance(meth, MemberNodes.MethodNode)])
            m = getMethod(methods, self.ID.name, self.args)

        if not m:
            raise Exception("ERROR: Class {} doesn't have a method {} with given argument types.".format(self.typeName, self.ID.name))

        # check static
        if self.ID.shouldBeStatic and (not 'static' in m.mods):
            raise Exception("ERROR: Static access of non-static method {}.".format(m.name))
        if (not self.ID.shouldBeStatic) and 'static' in m.mods:
            raise Exception("ERROR: Non-static access of static method {}.".format(m.name))

        # check protected
        if "protected" in m.mods:
            checkProtected(m, self)

        self.method = m
        self.myType = m.methodType.myType

    def reachCheck(self, inMaybe):
        """Reachability check: the invocation must be reachable; reachability passes through."""
        if not inMaybe:
            # MethodInvNode has no `name` attribute (the old message referenced
            # self.name and crashed); the method's name lives on self.ID.
            raise Exception("ERROR: not reaching a method invocation statement for method {}".format(self.ID.name))
        self.outMaybe = inMaybe

################# Helper #######################

def getMethod(methods, methodName, args):
    """Return the first method whose name and parameter types match the call.

    Args:
        methods: iterable of method/constructor nodes (each with .name and .params).
        methodName: name to match, or "" to match any name (used for constructors).
        args: ArgsNode-like object whose .exprs carry a populated .myType.

    Returns:
        The first matching node, or None when nothing matches.
    """
    argExprs = args.exprs
    for candidate in methods:
        if methodName != "" and candidate.name != methodName:
            continue
        if len(argExprs) != len(candidate.params):
            continue
        # compare every parameter type with the corresponding argument type
        mismatches = [
            position
            for position, argExpr in enumerate(argExprs)
            if candidate.params[position].paramType.myType != argExpr.myType
        ]
        if not mismatches:
            return candidate
    return None