Skip to content
Snippets Groups Projects
AST.py 2.97 KiB
import pprint

class ASTNode():
    """Base class for a node in the AST.

    Default fields:
        children:  list of child nodes. These nodes may also be stored in
                   other fields of the object; the pointers are deliberately
                   stored twice so generic recursion (``recurseAction``) can
                   walk the tree without knowing each subclass's layout.
        parseTree: the parse tree that corresponds to this AST node. This is
                   redundant and could be cleaned up after AST construction,
                   but it is kept for easier debugging, since efficiency is
                   not a concern here.
    """

    def __init__(self, parseTree):
        self.parseTree = parseTree
        # Documented default field; subclasses replace it with their
        # actual child nodes.
        self.children = []

    def recurseAction(self, actionName, args=None):
        """Run the method named ``actionName`` on this node and all descendants.

        The method is called on this node first, then recursively on each
        child. Its return value is passed as the single argument to the same
        method on every child. This lets an action thread state (e.g. an
        environment or an indent level) down the tree.

        Args:
            actionName: name of a method present on every node in the tree.
            args: argument for this node's call; when ``None`` the method is
                called with no arguments so its own defaults apply.

        Raises:
            AttributeError: if a node has no attribute named ``actionName``.
        """
        func = getattr(self, actionName)
        # Compare against None rather than truthiness: falsy results such as
        # 0 or an empty environment must still be forwarded to the children.
        result = func() if args is None else func(args)
        # Tolerate subclasses that never call ASTNode.__init__ and leave
        # ``children`` unset; such a node is treated as a leaf.
        for child in getattr(self, "children", []):
            child.recurseAction(actionName, result)

    def buildEnv(self, parentEnv):
        """Store ``parentEnv`` as this node's ``env`` and return it unchanged.

        Used with ``recurseAction``, the returned value is what the children
        receive as their ``parentEnv``.
        """
        self.env = parentEnv
        return parentEnv

    def printNodePretty(self, prefix=0):
        """Print this node's class name and attributes; return ``prefix + 1``.

        Intended for ``recurseAction``: the returned value becomes the
        children's ``prefix``. ``prefix`` is used as the PrettyPrinter
        ``indent``, which only affects nested structures, so top-level lines
        are not visibly indented by depth.
        """
        pp = pprint.PrettyPrinter(indent=prefix)
        pp.pprint(self.__class__.__name__)
        pp.pprint(vars(self))
        pp.pprint("-----children-----")
        return prefix + 1

    def printTree(self):
        """Pretty-print this node and, recursively, all of its descendants."""
        self.recurseAction('printNodePretty')

    def printEnv(self, prefix=0):
        """Print this node's class name and its environment's attributes.

        The environment is skipped when it is unset (e.g. ``buildEnv`` has
        not run) or falsy. Returns ``prefix + 1`` for use with
        ``recurseAction``.
        """
        pp = pprint.PrettyPrinter(indent=prefix)
        pp.pprint(self.__class__.__name__)
        env = getattr(self, "env", None)
        if env:
            pp.pprint(vars(env))
        pp.pprint("-----children-----")
        return prefix + 1




# Utils ######################################################

def getParseTreeNodes(names, tree, terminateList=None):
    """Collect the parse-tree nodes whose name is in ``names``.

    The tree is searched depth first, left to right. When a node matches, it
    is added to the result and its subtree is not searched further, so no
    returned node is an ancestor of another returned node. A non-matching
    node whose name is in ``terminateList`` stops the search: neither it nor
    its descendants are examined. The root itself is only checked against
    ``names``, never against ``terminateList``.

    Args:
        names: collection of node names to look for.
        tree: root parse-tree node; must expose ``name`` and ``children``.
        terminateList: optional collection of node names at which the
            recursive search stops.

    Returns:
        A list of matching nodes in traversal order: ``[tree]`` if the root
        itself matches, ``[]`` if nothing matches.
    """
    if terminateList is None:
        terminateList = ()
    if tree.name in names:
        return [tree]
    result = []
    for child in tree.children or ():
        if child.name in names:
            result.append(child)
        elif child.name in terminateList:
            continue
        else:
            # Forward terminateList so termination applies at every depth,
            # not only among the root's direct children.
            result.extend(getParseTreeNodes(names, child, terminateList))
    return result


def getTypeName(node):
    """Return ``(isPrimitiveType, typeName)`` for a ``type`` parse-tree node.

    Primitive types are detected first, by searching for BOOLEAN, BYTE,
    CHAR, INT or SHORT nodes. If none is found, the node is treated as a
    reference type and ID / COMPID nodes are searched instead. ``typeName``
    is the ``lex`` of the last matching node, or ``''`` if nothing matches.

    Args:
        node: a parse-tree node whose name is expected to be ``'type'``
            (not checked here).

    Returns:
        (isPrimitiveType: bool, typeName: str)
    """
    # ``or []`` defensively treats a None result from getParseTreeNodes as
    # "no matches" instead of crashing when it is indexed below.
    nameNodes = getParseTreeNodes(
        ['BOOLEAN', 'BYTE', 'CHAR', 'INT', 'SHORT'], node) or []
    isPrimType = bool(nameNodes)
    if not isPrimType:
        # get refType
        nameNodes = getParseTreeNodes(['ID', 'COMPID'], node) or []
    # When several nodes match, the last one in the returned list wins.
    typeName = nameNodes[-1].lex if nameNodes else ''
    return (isPrimType, typeName)