Skip to content
Snippets Groups Projects
LineNodes.py 18.06 KiB
from AST import ASTNode, getParseTreeNodes, getASTNode
from Environment import Env
from ExprPrimaryNodes import makeNodeFromExpr, makeNodeFromAllPrimary, MethodInvNode, ClassCreateNode
from TheTypeNode import TypeNode, TypeStruct
from CodeGenUtils import p, getCFlowLabel, iffalse

# Contains:
# block
# for/while/if
# declaration
# return statement


###########################################################
# factory methods
###########################################################

# Creates AST node from statement and statementNoShortIf
def makeNodeFromAllStatement(parseTree, typeName):
    """Build the AST node for a statement / statementNoShortIf parse node.

    The decision is made on the name of the first child of ``parseTree``.

    Returns the node produced for that child. ``None`` is returned when the
    child's name matches none of the handled kinds (and whenever the
    noTailStatement factory itself yields ``None``).
    """
    stmt = parseTree.children[0]
    kind = stmt.name

    if kind == 'noTailStatement':
        return makeNodeFromNoTailStatement(stmt, typeName)
    if kind in ('ifStatement', 'ifElseStatement', 'ifElseStatementNoShortIf'):
        return IfNode(stmt, typeName)
    if kind in ('forStatement', 'forStatementNoShortIf'):
        return ForNode(stmt, typeName)
    if kind in ('whileStatement', 'whileStatementNoShortIf'):
        return WhileNode(stmt, typeName)
    if kind == 'variableDcl':
        # Third argument switches on the definite-assignment check.
        return VarDclNode(stmt, typeName, True)
    return None


# Creates AST node from statementExpr
def makeNodeFromStatementExpr(parseTree, typeName):
    """Build the AST node for a statementExpr parse node.

    Looks at the first child of ``parseTree`` and hands it to the matching
    constructor: assignment, method invocation, or class creation.

    Returns ``None`` when the child's name is none of those three.
    """
    expr = parseTree.children[0]
    kind = expr.name

    if kind == 'assignment':
        return makeNodeFromExpr(expr, typeName)
    if kind == 'methodInvoc':
        return MethodInvNode(expr, typeName)
    if kind == "unqualCreate":
        return ClassCreateNode(expr, typeName)
    return None

# Creates AST node from noTailStatement
def makeNodeFromNoTailStatement(parseTree, typeName):
    """Build the AST node for a noTailStatement parse node.

    The first child of ``parseTree`` selects the result:
      * 'block'           -> BlockNode
      * 'exprStatement'   -> node built from its first child (a statementExpr)
      * 'returnStatement' -> ReturnNode
      * 'SEMICO'          -> None (the empty statement has no node)

    Any other child name also yields ``None``.
    """
    inner = parseTree.children[0]
    kind = inner.name

    if kind == 'block':
        return BlockNode(inner, typeName)
    if kind == 'exprStatement':
        return makeNodeFromStatementExpr(inner.children[0], typeName)
    if kind == 'returnStatement':
        return ReturnNode(inner, typeName)
    # 'SEMICO' and unrecognised kinds: no AST node is produced.
    return None

###########################################################
# end of factory methods
###########################################################

# block
# Rules:
# block LBRACK statements RBRACK
class BlockNode(ASTNode):
    """AST node for a block: a bracketed sequence of statements.

    Attributes set here:
        parseTree  -- the originating parse tree node
        statements -- one entry per 'statement' found under the block; an
                      entry is None when the statement factory returned None
                      (e.g. for an empty statement)
        children   -- the very same list object as ``statements``
        env        -- None until buildEnv is called
        typeName   -- name of the enclosing class, used in error messages
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.typeName = typeName

        # parseTree.children[1] is the subtree between the two brackets.
        statementTrees = getParseTreeNodes(['statement'], parseTree.children[1])
        self.statements = [
            makeNodeFromAllStatement(tree, typeName) for tree in statementTrees
        ]
        # children is an alias of statements, not a copy.
        self.children = self.statements

    def buildEnv(self, parentEnv):
        """Create this block's Env from ``parentEnv``, store and return it."""
        self.env = Env(parentEnv)
        return self.env

    def reachCheck(self, inMaybe):
        """Propagate reachability through the statements in order.

        ``inMaybe`` is the 'in' value of the block; the first statement gets
        the same value, every later statement gets the ``outMaybe`` of the
        previous non-empty statement. The final value becomes ``self.outMaybe``.

        Raises Exception when the block itself, or an empty statement inside
        it, is not reachable. Nested statements may raise from their own
        reachCheck.
        """
        if not inMaybe:
            raise Exception("ERROR: cannot reach block node in class {}".format(self.typeName))

        reachable = inMaybe  # in[S1] = in[L]
        for stmt in self.statements:
            if not stmt:
                # Empty statement: there is no node to ask, so check it here.
                if not reachable:
                    raise Exception("ERROR: empty statement is unreachable at block node for class {}".format(self.typeName))
                continue
            stmt.reachCheck(reachable)
            reachable = stmt.outMaybe
        self.outMaybe = reachable

# variableDcl
# Rules:
# 1. variableDcl type ID
# 2. variableDcl type ID ASSIGN variableInit
class VarDclNode(ASTNode):
    """AST node for a variable declaration, with or without an initializer.

    Attributes set here:
        dclType      -- TypeNode built from the declared type
        name         -- the variable's identifier (lexeme of the ID token)
        variableInit -- expression node of the initializer, or None
        myType       -- copied from ``dclType.myType``
        offset       -- stack offset used by codeGen/addr; starts at 0
        isField      -- stored as passed in
        children     -- [dclType, variableInit]
    """

    def __init__(self, parseTree, typeName, checkAssign=False, isField=None):
        """Build the node from a variableDcl parse tree.

        When ``checkAssign`` is true the declaration must have an initializer
        and must not mention its own name inside it; otherwise an Exception
        is raised.
        """
        self.parseTree = parseTree
        self.typeName = typeName
        self.isField = isField
        self.env = None
        self.offset = 0  # offset on the stack

        self.dclType = TypeNode(parseTree.children[0], typeName)
        self.name = parseTree.children[1].lex  # variable name

        # Rule 2 (type ID ASSIGN variableInit) has more than two children;
        # the expression sits under the variableInit node at index 3.
        if len(parseTree.children) > 2:
            self.variableInit = makeNodeFromExpr(parseTree.children[3].children[0], typeName)
        else:
            self.variableInit = None

        if checkAssign:
            # Definite assignment: an initializer is mandatory ...
            if not self.variableInit:
                raise Exception("ERROR: local variable declaration {} is not assigned in class {}".format(self.name, self.typeName))
            # ... and it may not refer to the variable being declared.
            for nameNode in getASTNode(["NameNode"], self.variableInit):
                if self.name in nameNode.IDs:
                    raise Exception("ERROR: local variable {} appears in it's own intialization in class {}".format(self.name, self.typeName))

        self.myType = self.dclType.myType
        self.children = [self.dclType, self.variableInit]

    def buildEnv(self, parentEnv):
        """Create an Env for this declaration and register the variable in it.

        Raises Exception if ``parentEnv.findNode(self.name, 'expr')`` already
        finds something under this name.
        """
        scope = Env(parentEnv)
        self.env = scope
        if parentEnv.findNode(self.name, 'expr'):
            raise Exception("ERROR: Double Local Variable Declaration {}".format(self.name))
        scope.addtoEnv(self)
        return self.env

    def checkType(self):
        """Type-check the initializer and require it to be assignable to the declared type."""
        init = self.variableInit
        if not init:
            return
        init.checkType()
        if self.myType.assignable(init.myType):
            return
        raise Exception("ERROR: Cannot initialize variable of type {} with type {}".format(self.myType.name, self.variableInit.myType.name))

    def reachCheck(self, inMaybe):
        """A declaration passes reachability through unchanged; raise if it is unreachable."""
        if not inMaybe:
            raise Exception("ERROR: not reaching a variable declaration statement for var {} in class {}".format(self.name, self.typeName))
        self.outMaybe = inMaybe

    def codeGen(self):
        """Fill ``self.code``; does nothing if non-empty code is already present."""
        if getattr(self, "code", "") != "":
            return

        self.code = ""
        init = self.variableInit
        if init:
            init.codeGen()
            # Initializer code first, then store eax into the variable's slot.
            store = p("mov", "[ebp - " + str(self.offset) + "]", "eax")
            self.code = init.code + store

    def addr(self):
        """Return code that leaves ``ebp - offset`` in eax."""
        return p("mov", "eax", "ebp") + p("sub", "eax", str(self.offset))


# ifStatement, ifElseStatement, ifElseStatementNoShortIf
# Rules:
# 1. ifStatement IF LPAREN expr RPAREN statement
# 2. ifElseStatement IF LPAREN expr RPAREN statementNoShortIf ELSE statement
# 3. ifElseStatementNoShortIf IF LPAREN expr RPAREN statementNoShortIf ELSE statementNoShortIf
class IfNode(ASTNode):
    """AST node for if / if-else statements.

    Attributes set here:
        ifConditional -- expression node for the condition
        ifBody        -- statement node for the 'then' branch; None when the
                         branch is an empty statement (the statement factory
                         returns None for ';')
        elseBody      -- statement node for the 'else' branch; None when
                         there is no else or the else branch is empty
        children      -- [ifConditional, ifBody, elseBody]
        typeName      -- name of the enclosing class, used in error messages
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.children = []
        self.ifConditional = None # the check for the if statement
        self.ifBody = None # the body of the if statement
        self.elseBody = None # there are if statements without the else statement
        self.typeName = typeName

        # Rule layout: IF LPAREN expr RPAREN statement [ELSE statement]
        self.ifConditional = makeNodeFromExpr(parseTree.children[2], typeName)
        self.ifBody = makeNodeFromAllStatement(parseTree.children[4], typeName)
        if parseTree.name == 'ifElseStatement' or parseTree.name == "ifElseStatementNoShortIf":
            self.elseBody = makeNodeFromAllStatement(parseTree.children[6], typeName)

        self.children.append(self.ifConditional)
        self.children.append(self.ifBody)
        self.children.append(self.elseBody)

    def checkType(self):
        """Type-check the condition (must be boolean) and both branches.

        Raises Exception when the condition's type name is not 'boolean'.
        """
        self.ifConditional.checkType()
        if self.ifConditional.myType.name != 'boolean':
            raise Exception("ERROR: Cannot use non-boolean type for ifConditional.")

        # Either branch may be None (empty statement); nothing to check then.
        if self.ifBody:
            self.ifBody.checkType()
        if self.elseBody:
            self.elseBody.checkType()

    def reachCheck(self, inMaybe):
        """Compute ``self.outMaybe`` from ``inMaybe`` and the branches.

            L : if (E) S            in[S] = in[L],  out[L] = in[L]
            L : if (E) S1 else S2   in[S1] = in[S2] = in[L],
                                    out[L] = out[S1] V out[S2]

        A branch that is an empty statement has no node; its out value is
        its in value, i.e. ``inMaybe``.

        Raises Exception when the if statement itself is unreachable.
        """
        if not inMaybe:
            raise Exception("in[s] = no for IfNode in class {}".format(self.typeName))

        # No need to check an empty ifBody for reachability: in[ifBody] = inMaybe.
        if self.ifBody:
            self.ifBody.reachCheck(inMaybe)

        if not self.elseBody:
            # Either there is no else, or the else branch is empty; in both
            # cases out[L] reduces to in[L].
            self.outMaybe = inMaybe
            return

        self.elseBody.reachCheck(inMaybe)
        # An empty 'then' branch completes normally whenever it is reached.
        ifOut = self.ifBody.outMaybe if self.ifBody else inMaybe
        self.outMaybe = (ifOut or self.elseBody.outMaybe)

    def codeGen(self):
        """Fill ``self.code``; does nothing if non-empty code is already present.

        Emitted layout: the code returned by ``iffalse`` for the condition
        and the else/end label, the 'then' branch code, and -- when an else
        branch exists -- a jump to the end label, the else label and the
        else branch code. The end label closes the statement.
        """
        if hasattr(self, "code") and self.code != "":
            return

        n = getCFlowLabel()
        elseLabel = "_else" + n
        endLabel = "_end" + n

        self.code = "; start of if clause" + n + "\n"

        if self.elseBody:
            self.code += iffalse(self.ifConditional, elseLabel)
        else:
            self.code += iffalse(self.ifConditional, endLabel)

        self.code += "; start of ifBody code for if clause" + n + "\n"

        # An empty 'then' branch contributes no code.
        if self.ifBody:
            self.ifBody.codeGen()
            self.code += self.ifBody.code

        if self.elseBody:
            self.code += "; start of elseBody code for if clause" + n + "\n"
            self.code += p("jmp", endLabel)
            self.code += elseLabel + ":\n"

            self.elseBody.codeGen()
            self.code += self.elseBody.code

        self.code += endLabel + ":\n"
        self.code += "; end of if clause" + n + "\n"



# whileStatement, whileStatementNoShortIf
# Rules:
# 1. whileStatement WHILE LPAREN expr RPAREN statement
# 2. whileStatementNoShortIf WHILE LPAREN expr RPAREN statementNoShortIf
class WhileNode(ASTNode):
    """AST node for a while loop.

    Attributes set here:
        whileBound -- expression node for the loop condition
        whileBody  -- statement node for the body; None for an empty statement
        children   -- [whileBound, whileBody]
        typeName   -- name of the enclosing class, used in error messages
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.typeName = typeName

        # Rule layout: WHILE LPAREN expr RPAREN statement
        self.whileBound = makeNodeFromExpr(parseTree.children[2], typeName)
        self.whileBody = makeNodeFromAllStatement(parseTree.children[4], typeName)

        self.children = [self.whileBound, self.whileBody]

    def checkType(self):
        """Type-check condition and body; the condition must be boolean."""
        bound = self.whileBound
        if bound:
            bound.checkType()
            if bound.myType.name != 'boolean':
                raise Exception("ERROR: Cannot use non-boolean type for whileBound.")
        if self.whileBody:
            self.whileBody.checkType()

    def reachCheck(self, inMaybe):
        """Compute ``self.outMaybe`` and check the body's reachability.

        The condition's constant value (via ``getConstant`` when the node
        offers it, else None) decides the case:
            not constant   -> out = in,    body in = in
            false / 0      -> out = in,    body in = False
            anything else  -> out = False, body in = in

        Raises Exception when the loop is unreachable, or when an empty body
        is unreachable; a non-empty body raises from its own reachCheck.
        """
        if not inMaybe:
            raise Exception("in[s] = no for WhileNode in class {}".format(self.typeName))

        # None means "not a constant expression".
        bound = self.whileBound
        con = bound.getConstant() if hasattr(bound, "getConstant") else None

        if con is None:
            loopOut, bodyIn = inMaybe, inMaybe      # while (E) S
        elif con in (False, 0):
            loopOut, bodyIn = inMaybe, False        # while (false) S
        else:
            loopOut, bodyIn = False, inMaybe        # while (true) S
        self.outMaybe = loopOut

        if self.whileBody:
            self.whileBody.reachCheck(bodyIn)
        elif not bodyIn:
            # Empty body has no node to raise for us.
            raise Exception("ERROR: unreachable empty statment/block at while node for class {}".format(self.typeName))

    def codeGen(self):
        """Fill ``self.code``; does nothing if non-empty code is already present.

        Emitted layout: start label, the code returned by ``iffalse`` for the
        condition and the end label, body code, jump back to the start label,
        end label.
        """
        if getattr(self, "code", "") != "":
            return

        n = getCFlowLabel()
        startLabel = "_start" + n
        endLabel = "_end" + n

        pieces = [
            "; start of while clause" + n + "\n",
            startLabel + ":\n",
            iffalse(self.whileBound, endLabel),
        ]
        if self.whileBody:
            self.whileBody.codeGen()
            pieces.append(self.whileBody.code)
        pieces.append(p("jmp", startLabel))
        pieces.append(endLabel + ":\n")
        pieces.append("; end of while clause" + n + "\n")

        self.code = "".join(pieces)



# returnStatement
# Rules:
# 1. returnStatement RETURN expr SEMICO
# 2. returnStatement RETURN SEMICO
class ReturnNode(ASTNode):
    """AST node for a return statement, with or without a value.

    Attributes set here:
        expr     -- expression node of the returned value, or None
        method   -- None here; codeGen reads ``self.method.label``, so it
                    must be assigned elsewhere before code generation
        children -- [expr]
        typeName -- name of the enclosing class, used in error messages
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.env = None
        self.typeName = typeName
        self.method = None # methodNode that this return belongs to

        # Three children means RETURN expr SEMICO; otherwise no value.
        hasValue = len(parseTree.children) == 3
        self.expr = makeNodeFromExpr(parseTree.children[1], typeName) if hasValue else None

        self.children = [self.expr]

    def disambigName(self):
        """Forward name disambiguation to the returned expression, if any."""
        if self.expr:
            self.expr.disambigName()

    def checkType(self):
        """Set ``self.myType`` to the expression's type, or None for a bare return."""
        if not self.expr:
            # None rather than a void type: returning a void value is invalid
            # even inside a void method.
            self.myType = None
            return
        self.expr.checkType()
        self.myType = self.expr.myType

    def reachCheck(self, inMaybe):
        """A return never completes normally: out[L] = no. Raise if unreachable."""
        if not inMaybe:
            raise Exception("ERROR: return statement unreachable at class {}".format(self.typeName))
        self.outMaybe = False

    def codeGen(self):
        """Fill ``self.code``; does nothing if non-empty code is already present.

        Emits the expression's code (when there is one) followed by a jump to
        the label ``<method label>_end``.
        """
        if getattr(self, "code", "") != "":
            return

        self.code = ""
        value = self.expr
        if value:
            value.codeGen()
            self.code += value.code
        self.code += p("jmp", self.method.label + "_end", None, "Return to the end of this method.")

# forStatement and forStatementNoShortIf
# Rules:
# 1. forStatement FOR LPAREN forInit SEMICO forExpr SEMICO forInit RPAREN statement
# 2. forStatementNoShortIf FOR LPAREN forInit SEMICO forExpr SEMICO forInit RPAREN statementNoShortIf
class ForNode(ASTNode):
    """AST node for a for loop.

    Attributes set here (each of the first three may be None because the
    corresponding grammar part can derive empty):
        forInit       -- initialisation: a statement expression or a VarDclNode
        forBound      -- loop condition expression
        forUpdate     -- update statement expression
        bodyStatement -- loop body; None when the body is an empty statement
        children      -- [forInit, forBound, forUpdate, bodyStatement]
        typeName      -- name of the enclosing class, used in error messages
    """

    def __init__(self, parseTree, typeName):
        self.parseTree = parseTree
        self.forInit = None # could be None
        self.forBound = None # could be None
        self.forUpdate = None # could be None
        self.bodyStatement = None
        self.env = None
        self.children = []
        self.typeName = typeName

        # The grammar uses 'forInit' for both the initialiser and the update
        # slot; the first one encountered is the initialiser, the second the
        # update. seenInit records that the first slot has been passed, even
        # when that slot was empty.
        seenInit = False
        for node in parseTree.children:
            if node.name == 'forInit':

                # Handling case where forInit could derive empty
                if not node.children:
                    seenInit = True
                    continue

                statementNode = node.children[0]
                exprAstNode = None

                if statementNode.name == 'statementExpr':
                    exprAstNode = makeNodeFromStatementExpr(statementNode, typeName)
                elif statementNode.name == 'variableDcl':
                    exprAstNode = VarDclNode(statementNode, typeName)

                if not seenInit:
                    self.forInit = exprAstNode
                    seenInit = True
                else:
                    self.forUpdate = exprAstNode

            elif node.name == 'forExpr':

                # Handling case where forExpr could derive empty
                if not node.children:
                    continue

                self.forBound = makeNodeFromExpr(node.children[0], typeName)

            elif node.name == 'statement' or node.name == 'statementNoShortIf':
                self.bodyStatement = makeNodeFromAllStatement(node, typeName)

        self.children.append(self.forInit)
        self.children.append(self.forBound)
        self.children.append(self.forUpdate)
        self.children.append(self.bodyStatement)

    def checkType(self):
        """Type-check init, condition (must be boolean), update and body.

        Raises Exception when the condition's type name is not 'boolean'.
        """
        if self.forInit:
            self.forInit.checkType()
        if self.forBound:
            self.forBound.checkType()  # need resolving var declared in forInit to use it in forBound
            if self.forBound.myType.name != 'boolean':
                raise Exception("ERROR: Cannot use non-boolean type for forBound.")
        if self.forUpdate:
            self.forUpdate.checkType()
        # The body is None for an empty statement, e.g. ``for (...) ;``.
        if self.bodyStatement:
            self.bodyStatement.checkType()

    def reachCheck(self, inMaybe):
        """Compute ``self.outMaybe`` and check the body's reachability.

        The condition's constant value decides the case:
            not constant          -> out = in,    body in = in
            false / 0             -> out = in,    body in = False
            true / other constant -> out = False, body in = in
        An omitted condition is handled like constant true, which matches
        the unconditional loop that codeGen emits for it.

        Raises Exception when the loop is unreachable, or when an empty body
        is unreachable; a non-empty body raises from its own reachCheck.
        """
        if not inMaybe:
            raise Exception("in[s] = no for ForNode in class {}".format(self.typeName))

        # Checking constant expression in forBound
        if not self.forBound:
            con = True # no condition: the loop can only be left by other means
        elif hasattr(self.forBound, "getConstant"):
            con = self.forBound.getConstant()
        else:
            con = None # not a constant expression

        # Setting self.outMaybe
        inMaybeForBody = inMaybe # the input to reachCheck on bodyStatement
        # General case
        if con is None:
            self.outMaybe = inMaybe
        # for(...; false; ...) S
        elif con == False or con == 0:
            self.outMaybe = inMaybe
            inMaybeForBody = False
        else: # either an integer that's not zero or True
            self.outMaybe = False

        # Checking reachability on bodyStatement
        if self.bodyStatement:
            self.bodyStatement.reachCheck(inMaybeForBody)
        elif not inMaybeForBody: # empty body that cannot be reached
            raise Exception("ERROR: unreachable empty statement/block at for node at class {}".format(self.typeName))
        return

    def codeGen(self):
        """Fill ``self.code``; does nothing if non-empty code is already present.

        Emitted layout: init code, start label, the code returned by
        ``iffalse`` for the condition and the end label (only when a
        condition exists), body code, update code, jump back to the start
        label, end label.
        """
        if hasattr(self, "code") and self.code != "":
            return

        n = getCFlowLabel()
        startLabel =  "_start" + n
        endLabel = "_end" + n

        self.code = "; start of for clause" + n + "\n"
        if self.forInit:
            self.forInit.codeGen()
            self.code += self.forInit.code
        self.code += startLabel + ":\n"
        if self.forBound:
            self.code += iffalse(self.forBound, endLabel)

        if self.bodyStatement:
            self.bodyStatement.codeGen()
            self.code += self.bodyStatement.code
        if self.forUpdate:
            self.forUpdate.codeGen()
            self.code += self.forUpdate.code
        self.code += p("jmp", startLabel)

        self.code += endLabel + ":\n"
        self.code += "; end of for clause" + n + "\n"