from AST import ASTNode, getParseTreeNodes
from Environment import Env
from UnitNodes import LiteralNode
from TheTypeNode import TypeNode, TypeStruct
from NameNode import NameNode

# File containing the smaller (lower-level) nodes of the AST.
# Nodes in this file:
#   ArgsNode
#   ArrayAccessNode
#   ArrayCreateNode
#   AssignNode
#   CastNode
#   ClassCreateNode
#   ExprNode
#   FieldAccessNode
#   MethodInvNode

# TODO: go over the nodes in this file to see if they need to override buildEnv

###########################################################
# factory methods
###########################################################

# parses an expr node in the parse tree into a NameNode, AssignNode, CastNode, ExprNode, or a primary node
def makeNodeFromExpr(parseTree, typeName):
    """Convert an expression parse tree into the matching AST node.

    Descends through single-child wrapper nodes until it reaches a node that
    maps directly to an AST node:
        primaryAndArray -> makeNodeFromAllPrimary(...)
        ID / COMPID     -> NameNode
        assignment      -> AssignNode
        castExpr        -> CastNode (only if it has more than one child)
        anything else with != 1 child -> ExprNode

    Args:
        parseTree: the parse-tree node to convert.
        typeName: the type (class/interface) the resulting node belongs under.

    Returns:
        The AST node built for the expression.
    """
    c = parseTree
    while True:
        if c.name == 'primaryAndArray':
            return makeNodeFromAllPrimary(c, typeName)
        if c.name == 'ID' or c.name == 'COMPID':
            # A bare name in expression position is a variable/field name, not a
            # method name. The other NameNode call sites in this file pass False
            # here for names and True for method names; passing typeName (a
            # truthy string) wrongly marked every such name as a method name.
            return NameNode(c, False, typeName)
        if c.name == 'assignment':
            return AssignNode(c, typeName)
        if len(c.children) == 1:
            # single-child wrapper (e.g. a precedence level with no operator): skip it
            c = c.children[0]
            continue
        if c.name == 'castExpr':
            return CastNode(c, typeName)
        return ExprNode(c, typeName)

# parses the primaryAndArray/primary/primaryNoArrayAccess node in the parse tree and return corresponding AST nodes
def makeNodeFromAllPrimary(parseTree, typeName):
    """Build the AST node for a primaryAndArray / primary / primaryNoArrayAccess tree.

    Array creation and array access are recognised while unwrapping the outer
    layers. Otherwise the first child of the remaining (primaryNoArrayAccess)
    node selects the node class.

    Args:
        parseTree: a primaryAndArray, primary or primaryNoArrayAccess node.
        typeName: the type (class/interface) the resulting node belongs under.

    Raises:
        Exception: if the first child is not a recognised primary form.
    """
    current = parseTree

    if current.name == 'primaryAndArray':
        head = current.children[0]
        if head.name == 'arrayCreationExpr':
            return ArrayCreateNode(head, typeName)
        if head.name == 'primary':
            # fall through to the shared `primary` handling below
            current = head

    if current.name == 'primary':
        head = current.children[0]
        if head.name == 'arrayAccess':
            return ArrayAccessNode(head, typeName)
        current = head

    first = current.children[0]
    if first.name == 'LPAREN':  # primaryNoArrayAccess LPAREN expr RPAREN
        return makeNodeFromExpr(current.children[1], typeName)

    leafBuilders = {
        'literal': LiteralNode,
        'classInstanceCreate': ClassCreateNode,
        'methodInvoc': MethodInvNode,
        'fieldAccess': FieldAccessNode,
    }
    builder = leafBuilders.get(first.name)
    if builder is None:
        raise Exception('ERROR: something wrong at primaryNoArrayAccess')
    return builder(first, typeName)

###########################################################
# helper methods
###########################################################
def helperDisambigName(node):
    """Call node.disambigName() when node is a NameNode; do nothing otherwise.

    Falsy values (e.g. None) are ignored. Any exception raised by
    disambigName() propagates unchanged to the caller.
    """
    # The check is by exact class name, not isinstance(), so subclasses of
    # NameNode are deliberately not matched here.
    if node and node.__class__.__name__ == "NameNode":
        node.disambigName()

###########################################################

########################### Node Definitions #####################################

###################################################################################
# args exprs, exprs expr COMMA exprs
class ArgsNode(ASTNode):
    """AST node for a method/constructor argument list.

    Every `expr` parse-tree node returned by getParseTreeNodes for this tree
    is converted with makeNodeFromExpr. The results are kept, in order, in
    both `exprs` and `children`.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.exprs = []  # one expression node per argument, in source order
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        argumentTrees = getParseTreeNodes(['expr'], parseTree)
        self.exprs.extend(makeNodeFromExpr(tree, typeName) for tree in argumentTrees)
        self.children.extend(self.exprs)

    def disambigName(self):
        """Disambiguate every argument that is a bare NameNode."""
        for argument in self.exprs:
            helperDisambigName(argument)


###################################################################################
# Array Access
class ArrayAccessNode(ASTNode):
    """AST node for an array access `array[index]`.

    Parse-tree shapes handled:
        arrayAccess name LSQRBRACK expr RSQRBRACK
        arrayAccess ID LSQRBRACK expr RSQRBRACK
        arrayAccess primaryNoArrayAccess LSQRBRACK expr RSQRBRACK
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.array = ''  # either a variableName, a field, or array access
        self.index = ''  # expr
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        arrayTree = parseTree.children[0]
        if arrayTree.name == 'primaryNoArrayAccess':
            self.array = makeNodeFromAllPrimary(arrayTree, typeName)
        else:
            # A `name` wrapper is unwrapped to its ID/COMPID token. A bare ID/COMPID
            # token is passed to NameNode as-is, the same way makeNodeFromExpr does.
            # Unwrapping a token (children[0] on a leaf) was wrong for the
            # `arrayAccess ID ...` form.
            if arrayTree.name not in ('ID', 'COMPID'):
                arrayTree = arrayTree.children[0]
            self.array = NameNode(arrayTree, False, typeName)
        self.index = makeNodeFromExpr(parseTree.children[2], typeName)

        self.children.append(self.array)
        self.children.append(self.index)

    def disambigName(self):
        """Disambiguate the array and index expressions when they are bare names."""
        helperDisambigName(self.array)
        helperDisambigName(self.index)

###################################################################################
# arrayCreationExpr
# arrayCreationExpr NEW primitiveType LSQRBRACK expr  RSQRBRACK
# arrayCreationExpr NEW name LSQRBRACK expr RSQRBRACK
# arrayCreationExpr NEW primitiveType LSQRBRACK RSQRBRACK
# arrayCreationExpr NEW name LSQRBRACK RSQRBRACK
class ArrayCreateNode(ASTNode):
    """AST node for an array creation `new T[size]` or `new T[]`.

    `arrayType` is a TypeNode built from the type child (children[1]).
    `arraySize` is the size expression node, or 0 when no size expression is
    present. NOTE(review): in that case the integer 0, not a node, is what
    ends up in `children`.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.arrayType = ''
        self.arraySize = 0  # or Expr
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        # children[0] is NEW, children[1] the element type (primitiveType or name)
        self.arrayType = TypeNode(parseTree.children[1], typeName)

        sizeTrees = getParseTreeNodes(['expr'], parseTree)
        if sizeTrees:
            self.arraySize = makeNodeFromExpr(sizeTrees[0], typeName)

        self.children += [self.arrayType, self.arraySize]

    def disambigName(self):
        """Disambiguate the size expression when it is a bare name."""
        helperDisambigName(self.arraySize)

###################################################################################
# assignment leftHandSide ASSIGN expr
class AssignNode(ASTNode):
    """AST node for an assignment `target = value`.

    `right` is built from children[2]. `left` is a FieldAccessNode or
    ArrayAccessNode when children[0] is one of those. Otherwise it is a
    NameNode built from the first child of children[0].
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.left = None
        # the value side is built before the target side
        self.right = makeNodeFromExpr(parseTree.children[2], typeName)
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        targetTree = parseTree.children[0]
        targetBuilders = {
            'fieldAccess': FieldAccessNode,
            'arrayAccess': ArrayAccessNode,
        }
        builder = targetBuilders.get(targetTree.name)
        if builder is not None:
            self.left = builder(targetTree, typeName)
        else:
            self.left = NameNode(targetTree.children[0], False, typeName)

        self.children.extend([self.right, self.left])

    def disambigName(self):
        """Disambiguate both sides of the assignment when they are bare names."""
        for side in (self.right, self.left):
            helperDisambigName(side)


##################################################################################
# cast: castExpr LPAREN castType RPAREN unaryNotPlusMinus
class CastNode(ASTNode):
    """AST node for a cast `(left) right`.

    Parse-tree shape: castExpr LPAREN castType RPAREN unaryNotPlusMinus.
    The cast target (children[1]) becomes an expression node when it was
    parsed as an `expr`. Otherwise (primitiveType or ArrayType) it becomes a
    TypeNode.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        castTarget = parseTree.children[1]
        self.left = castTarget  # cast: (left)right
        self.right = makeNodeFromExpr(parseTree.children[3], typeName)  # operand
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        # NOTE(review): presumably `expr` here is a reference-type name the grammar
        # could not distinguish from an expression — confirm against the grammar.
        buildLeft = makeNodeFromExpr if castTarget.name == 'expr' else TypeNode
        self.left = buildLeft(castTarget, typeName)

        self.children.extend([self.left, self.right])

    def disambigName(self):
        """Disambiguate the cast target and operand when they are bare names."""
        for part in (self.left, self.right):
            helperDisambigName(part)

###################################################################################
# classInstanceCreate unqualCreate
# unqualCreate NEW name LPAREN args RPAREN
class ClassCreateNode(ASTNode):
    """AST node for a class instance creation `new Name(args)`.

    Input is `classInstanceCreate unqualCreate`. The stored parseTree is the
    unqualCreate subtree (NEW name LPAREN args RPAREN).
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        unqualCreate = parseTree.children[0]
        self.parseTree = unqualCreate
        self.className = TypeNode(unqualCreate.children[1], typeName)
        self.args = ArgsNode(unqualCreate.children[3], typeName)
        self.env = None
        self.children = [self.args, self.className]
        self.typeName = typeName

    def checkType(self):
        """Reject instantiation of a class whose modifiers include 'abstract'.

        Raises:
            Exception: if the created class is abstract.
        """
        createdType = self.className.myType
        if 'abstract' in createdType.typePointer.mods:
            raise Exception('ERROR: Cannot create an instance of abstract class {}.'.format(createdType.name))
        # TODO: more type checking



#################################################################################
# condOrExpr
class ExprNode(ASTNode):
    """AST node for an operator expression.

    Unary form (parse-tree node `unaryNotPlusMinus` or `unaryExpr`):
        `op right`, with `left` left as None.
    Binary form: `left op right`.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.left = None
        self.op = ''
        self.right = None  # another expr
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        if parseTree.name == 'unaryNotPlusMinus' or parseTree.name == 'unaryExpr':
            self.op = parseTree.children[0].lex
            self.right = makeNodeFromExpr(parseTree.children[1], typeName)
        else:
            self.left = makeNodeFromExpr(parseTree.children[0], typeName)
            self.op = parseTree.children[1].lex
            self.right = makeNodeFromExpr(parseTree.children[2], typeName)

        self.children.append(self.left)
        self.children.append(self.right)

    def disambigName(self):
        """Disambiguate the operands when they are bare names."""
        helperDisambigName(self.left)
        helperDisambigName(self.right)

    # use wrong name to stop method from being called until we finish other implementation
    # def checkType(self):
    def checkType1(self):
        """Type-check this expression and set self.myType.

        Steps:
          1. check children's types (super().checkType() populates their myType);
          2. apply the rule for this operator;
          3. store the resulting TypeStruct in self.myType.

        Raises:
            Exception: if the operand types are incompatible with the operator.
        """
        super().checkType()  # check children's type first to populate their myType field

        # Unary operations. Compare with None explicitly: an operand node must not
        # be mistaken for "no left operand" just because it is falsy.
        if self.left is None:
            rightType = self.right.myType
            if self.op == '-' and rightType.isNum():
                self.myType = TypeStruct("int")
                return
            if self.op == '!' and rightType.name == 'boolean':
                self.myType = TypeStruct("boolean")
                return
            raise Exception("ERROR: Incompatible type. Unary operation {} can't be used with type {}".format(self.op, rightType.name))

        # instanceof is handled before any value-type rule, because its right
        # operand names a type rather than a value.
        # Assume it's correct for now, wait for runtime check.
        if self.op == 'instanceof':
            self.myType = TypeStruct("boolean")
            return

        leftType = self.left.myType
        rightType = self.right.myType

        # Numeric operands: comparisons yield boolean, arithmetic yields int.
        if leftType.isNum() and rightType.isNum():
            if self.op in ('==', '!=', '<=', '>=', '>', '<'):
                self.myType = TypeStruct("boolean")
                return
            if self.op in ('+', '-', '*', '/', '%'):
                self.myType = TypeStruct("int")
                return
        # Boolean operands: logical operators and equality yield boolean.
        elif leftType.name == 'boolean' and rightType.name == 'boolean':
            if self.op in ('&&', '&', '|', '||', '==', '!='):
                self.myType = TypeStruct("boolean")
                return

        # String concatenation applies only to '+'. It is checked before the
        # equality rule so that String + String is not rejected.
        # NOTE(review): Java also permits `"s" + null`; 'null' is still excluded
        # here as in the original rule.
        excluded = ('null', 'void')
        leftIsString = leftType.name == 'java.lang.String'
        rightIsString = rightType.name == 'java.lang.String'
        if self.op == '+' and ((leftIsString and rightType.name not in excluded)
                               or (rightIsString and leftType.name not in excluded)):
            self.myType = TypeStruct('java.lang.String')
            self.myType.link(self.env)
            return

        # Equality between other types: operands must be assignable in some direction.
        if self.op in ('==', '!=') and (leftType.assignable(rightType) or rightType.assignable(leftType)):
            self.myType = TypeStruct("boolean")
            return

        raise Exception("ERROR: Incompatible types. Left of {} type can't be used with right of {} type on operation {}".format(leftType.name, rightType.name, self.op))




###################################################################################
# fieldAccess primary PERIOD ID
class FieldAccessNode(ASTNode):
    """AST node for a field access `primary.ID`.

    Parse-tree shape: fieldAccess primary PERIOD ID. Only `primary` is added
    to `children`; the field name is kept separately in `ID`.
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.primary = ''
        self.ID = ''  # method/fieldName
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        primaryTree = parseTree.children[0]
        fieldToken = parseTree.children[2]
        self.primary = makeNodeFromAllPrimary(primaryTree, typeName)
        self.ID = NameNode(fieldToken, False, typeName)

        self.children.append(self.primary)

    def disambigName(self):
        """Disambiguate the field name only when nothing precedes it.

        NOTE(review): __init__ always sets `primary` to the result of
        makeNodeFromAllPrimary. Unless that node is falsy, this branch looks
        unreachable — confirm.
        """
        if self.primary:
            return
        helperDisambigName(self.ID)



###################################################################################
# methodInvoc
class MethodInvNode(ASTNode):
    """AST node for a method invocation.

    Parse-tree shapes handled:
        methodInvoc primary PERIOD ID LPAREN args RPAREN
        methodInvoc name LPAREN args RPAREN
    In both shapes the method name is children[-4] and the args are children[-2].
    """
    # always list all fields in the init method to show the class structure
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.primary = None  # can be empty
        self.ID = ''  # can be either ID or compID
        self.args = None
        self.env = None
        self.children = []
        self.typeName = typeName  # the type (class/interface) this node belongs under

        self.ID = NameNode(parseTree.children[-4], True, typeName)
        self.args = ArgsNode(parseTree.children[-2], typeName)
        if parseTree.children[0].name == 'primary':
            self.primary = makeNodeFromAllPrimary(parseTree.children[0], typeName)

        self.children.append(self.primary)
        self.children.append(self.args)

    def disambigName(self):
        """Disambiguate a qualified method name (e.g. `a.b.m(...)`) with no primary.

        The original test `'.' in self.ID` checked membership in the NameNode
        object rather than in the name's text. This version reads the lexeme of
        the ID/COMPID token taken from the parse tree instead.
        """
        if self.primary:
            return  # the name is resolved relative to the primary instead
        nameTree = self.parseTree.children[-4]
        if nameTree.name not in ('ID', 'COMPID'):
            # a `name` wrapper holds a single ID/COMPID token
            nameTree = nameTree.children[0]
        # NOTE(review): assumes COMPID tokens carry the dotted text in `.lex`.
        if '.' in nameTree.lex:
            helperDisambigName(self.ID)