from AST import ASTNode, getParseTreeNodes
from Environment import Env
from ExprPrimaryNodes import makeNodeFromExpr, makeNodeFromAllPrimary, MethodInvNode
from TheTypeNode import TypeNode, TypeStruct

# Contains:
# block
# for/while/if
# declaration
# return statement


###########################################################
# factory methods
###########################################################

# Creates AST node from statement and statementNoShortIf
def makeNodeFromAllStatement(parseTree, typeName):
    """Build the AST node for a `statement` / `statementNoShortIf` parse node.

    Dispatches on the name of the parse node's first child. Returns None when
    that child kind is not one of the handled kinds below (and whatever
    makeNodeFromNoTailStatement returns for a noTailStatement, which may also
    be None for an empty `;` statement).
    """
    inner = parseTree.children[0]
    kind = inner.name

    if kind == 'noTailStatement':
        return makeNodeFromNoTailStatement(inner, typeName)
    if kind in ('ifStatement', 'ifElseStatement', 'ifElseStatementNoShortIf'):
        return IfNode(inner, typeName)
    if kind in ('forStatement', 'forStatementNoShortIf'):
        return ForNode(inner, typeName)
    if kind in ('whileStatement', 'whileStatementNoShortIf'):
        return WhileNode(inner, typeName)
    if kind == 'variableDcl':
        return VarDclNode(inner, typeName)
    return None


# Creates AST node from statementExpr
def makeNodeFromStatementExpr(parseTree, typeName):
    """Build the AST node for a `statementExpr` parse node.

    An assignment is delegated to makeNodeFromExpr; a method invocation
    becomes a MethodInvNode. Any other child kind yields None.
    """
    head = parseTree.children[0]
    kind = head.name

    if kind == 'assignment':
        return makeNodeFromExpr(head, typeName)
    if kind == 'methodInvoc':
        return MethodInvNode(head, typeName)
    return None

# Creates AST node from noTailStatement
def makeNodeFromNoTailStatement(parseTree, typeName):
    """Build the AST node for a `noTailStatement` parse node.

    Returns None for the empty statement (`;`) and for any unhandled child
    kind; otherwise returns a BlockNode, the node built from the wrapped
    statementExpr, or a ReturnNode.
    """
    head = parseTree.children[0]
    kind = head.name

    if kind == 'SEMICO':
        # The empty statement has no AST representation.
        return None
    if kind == 'block':
        return BlockNode(head, typeName)
    if kind == 'exprStatement':
        # exprStatement wraps a statementExpr as its first child; unwrap it.
        return makeNodeFromStatementExpr(head.children[0], typeName)
    if kind == 'returnStatement':
        return ReturnNode(head, typeName)
    return None

###########################################################
# end of factory methods
###########################################################

# block
# Rules:
# block LBRACK statements RBRACK
class BlockNode(ASTNode):
    """AST node for a block of statements.

    `statements` holds one entry per `statement` parse node found under the
    block's second child; an entry is None when makeNodeFromAllStatement
    produced no node (e.g. an empty `;` statement). `children` is the very
    same list object as `statements`.
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.typeName = typeName
        self.env = None

        stmtTrees = getParseTreeNodes(['statement'], parseTree.children[1])
        self.statements = [makeNodeFromAllStatement(tree, typeName) for tree in stmtTrees]
        # Alias, not a copy: mutations of one are visible through the other.
        self.children = self.statements

    def buildEnv(self, parentEnv):
        """Create a new Env nested in parentEnv, store it and return it."""
        self.env = Env(parentEnv)
        return self.env


# variableDcl
# Rules:
# 1. variableDcl type ID
# 2. variableDcl type ID ASSIGN variableInit
class VarDclNode(ASTNode):
    """AST node for a local variable declaration.

    `dclType` is the TypeNode for the declared type, `name` the identifier's
    lexeme, and `variableInit` the initializer expression node (None when the
    declaration has no `= ...` part). `myType` mirrors dclType.myType.
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.typeName = typeName
        self.env = None

        self.dclType = TypeNode(parseTree.children[0], typeName)
        self.name = parseTree.children[1].lex  # variable name

        # Rule 2 (type ID ASSIGN variableInit): index 3 is variableInit, whose
        # first child is the initializer expression.
        self.variableInit = None
        if len(parseTree.children) > 2:
            initTree = parseTree.children[3].children[0]
            self.variableInit = makeNodeFromExpr(initTree, typeName)

        self.myType = self.dclType.myType
        self.children = [self.dclType, self.variableInit]

    def buildEnv(self, parentEnv):
        """Create a child Env of parentEnv holding this declaration.

        Raises if parentEnv.findNode already finds a node with this name
        under the 'expr' category.
        """
        self.env = Env(parentEnv)
        if parentEnv.findNode(self.name, 'expr'):
            raise Exception("ERROR: Double Local Variable Declaration {}".format(self.name))
        self.env.addtoEnv(self)
        return self.env

    def checkType(self):
        """Type-check the initializer (if any) and its assignability."""
        init = self.variableInit
        if not init:
            return
        init.checkType()
        if not self.myType.assignable(init.myType):
            raise Exception("ERROR: Cannot initialize variable of type {} with type {}".format(self.myType.name, init.myType.name))


# ifStatement, ifElseStatement, ifElseStatementNoShortIf
# Rules:
# 1. ifStatement IF LPAREN expr RPAREN statement
# 2. ifElseStatement IF LPAREN expr RPAREN statementNoShortIf ELSE statement
# 3. ifElseStatementNoShortIf IF LPAREN expr RPAREN statementNoShortIf ELSE statementNoShortIf
class IfNode(ASTNode):
    """AST node for the three `if` forms.

    Handles ifStatement (no else), ifElseStatement and
    ifElseStatementNoShortIf. For the two else forms the else branch sits at
    parse-child index 6. Either body may be None when it is an empty `;`
    statement, because makeNodeFromNoTailStatement returns None for SEMICO.
    """
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.children = []
        self.ifConditional = None # the check for the if statement
        self.ifBody = None # the body of the if statement
        self.elseBody = None # there are if statements without the else statement
        self.typeName = typeName

        self.ifConditional = makeNodeFromExpr(parseTree.children[2], typeName)
        self.ifBody = makeNodeFromAllStatement(parseTree.children[4], typeName)
        # Both else-carrying rules (ifElseStatement and ifElseStatementNoShortIf)
        # place the else branch after ELSE at index 6; previously the NoShortIf
        # variant's else branch was silently dropped.
        if parseTree.name in ('ifElseStatement', 'ifElseStatementNoShortIf'):
            self.elseBody = makeNodeFromAllStatement(parseTree.children[6], typeName)

        self.children.append(self.ifConditional)
        self.children.append(self.ifBody)
        self.children.append(self.elseBody)

    def checkType(self):
        """Require a boolean condition, then type-check both branches.

        Raises:
            Exception: if the condition's type is missing or not boolean.
        """
        self.ifConditional.checkType()
        condType = self.ifConditional.myType
        # Guard against a None type so the intended error is raised rather
        # than an AttributeError.
        if condType is None or condType.name != 'boolean':
            raise Exception("ERROR: Cannot use non-boolean type for ifConditional.")

        # An empty-statement body (`if (c);`) produces no AST node.
        if self.ifBody is not None:
            self.ifBody.checkType()
        if self.elseBody:
            self.elseBody.checkType()

# whileStatement, whileStatementNoShortIf
# Rules:
# 1. whileStatement WHILE LPAREN expr RPAREN statement
# 2. whileStatementNoShortIf WHILE LPAREN expr RPAREN statementNoShortIf
class WhileNode(ASTNode):
    """AST node for whileStatement / whileStatementNoShortIf.

    `whileBound` is the loop condition expression and `whileBody` the body,
    which is None when the body is an empty `;` statement (since
    makeNodeFromNoTailStatement returns None for SEMICO).
    """
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.children = []
        self.whileBound = None
        self.whileBody = None
        self.typeName = typeName

        self.whileBound = makeNodeFromExpr(parseTree.children[2], typeName)
        self.whileBody = makeNodeFromAllStatement(parseTree.children[4], typeName)

        self.children.append(self.whileBound)
        self.children.append(self.whileBody)

    def checkType(self):
        """Require a boolean loop condition, then type-check the body.

        Raises:
            Exception: if the condition's type is missing or not boolean.
        """
        self.whileBound.checkType()
        boundType = self.whileBound.myType
        # Guard against a None type so the intended error is raised rather
        # than an AttributeError.
        if boundType is None or boundType.name != 'boolean':
            raise Exception("ERROR: Cannot use non-boolean type for whileBound.")
        # An empty-statement body (`while (c);`) produces no AST node.
        if self.whileBody is not None:
            self.whileBody.checkType()

# returnStatement
# Rules:
# 1. returnStatement RETURN expr SEMICO
# 2. returnStatement RETURN SEMICO
class ReturnNode(ASTNode):
    """AST node for `return expr;` and bare `return;`.

    `expr` is the returned expression node, or None for a bare return.
    `children` is always a one-element list holding `expr` (possibly None).
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.typeName = typeName

        # Rule 1 (RETURN expr SEMICO) has three children; rule 2 has two.
        if len(parseTree.children) == 3:
            self.expr = makeNodeFromExpr(parseTree.children[1], typeName)
        else:
            self.expr = None
        self.children = [self.expr]

    def disambigName(self):
        """Forward name disambiguation to the returned expression, if any."""
        if self.expr:
            self.expr.disambigName()

    def checkType(self):
        """Set myType to the returned expression's type, or None if bare.

        A bare `return;` gets myType None (not a void type), so callers can
        tell "no value" apart from any concrete type.
        """
        if not self.expr:
            self.myType = None
            return
        self.expr.checkType()
        self.myType = self.expr.myType

# forStatement and forStatementNoShortIf
# Rules:
# 1. forStatement FOR LPAREN forInit SEMICO forExpr SEMICO forInit RPAREN statement
# 2. forStatementNoShortIf FOR LPAREN forInit SEMICO forExpr SEMICO forInit RPAREN statementNoShortIf
class ForNode(ASTNode):
    """AST node for forStatement / forStatementNoShortIf.

    Each of the three header slots may be empty, so forInit, forBound and
    forUpdate can be None. bodyStatement can also be None when the body is an
    empty `;` statement (makeNodeFromNoTailStatement returns None for SEMICO).
    """
    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.forInit = None # could be None
        self.forBound = None # could be None
        self.forUpdate = None # could be None
        self.bodyStatement = None # could be None (empty `;` body)
        self.env = None
        self.children = []
        self.typeName = typeName

        # The grammar reuses the forInit nonterminal for both the initializer
        # and the update clause: the first forInit child is the initializer,
        # the second is the update.
        InitFlag = False # flag for forInit vs forUpdate
        for node in parseTree.children:
            if node.name == 'forInit':

                # Handling case where forInit could derive empty
                if not node.children:
                    InitFlag = True
                    continue

                statementNode = node.children[0]
                exprAstNode = None

                if statementNode.name == 'statementExpr':
                    exprAstNode = makeNodeFromStatementExpr(statementNode, typeName)
                elif statementNode.name == 'variableDcl':
                    exprAstNode = VarDclNode(statementNode, typeName)

                if not InitFlag:
                    self.forInit = exprAstNode
                    InitFlag = True

                else:
                    self.forUpdate = exprAstNode

            elif node.name == 'forExpr':

                # Handling case where forExpr could derive empty
                if not node.children:
                    continue

                self.forBound = makeNodeFromExpr(node.children[0], typeName)

            elif node.name == 'statement' or node.name == 'statementNoShortIf':
                self.bodyStatement = makeNodeFromAllStatement(node, typeName)

        self.children.append(self.forInit)
        self.children.append(self.forBound)
        self.children.append(self.forUpdate)
        self.children.append(self.bodyStatement)

    def checkType(self):
        """Type-check every present part of the loop.

        Missing parts (e.g. in `for(;;)`) are skipped instead of dereferenced.

        Raises:
            Exception: if a present bound's type is missing or not boolean.
        """
        if self.forInit is not None:
            self.forInit.checkType()
        if self.forBound is not None:
            # NOTE(review): a variable declared in forInit must already be
            # resolvable here for forBound to type-check.
            self.forBound.checkType()
            boundType = self.forBound.myType
            if boundType is None or boundType.name != 'boolean':
                raise Exception("ERROR: Cannot use non-boolean type for forBound.")
        if self.forUpdate is not None:
            self.forUpdate.checkType()
        if self.bodyStatement is not None:
            self.bodyStatement.checkType()