From 014420b3c0dfc75b844c623840424931777700bc Mon Sep 17 00:00:00 2001
From: Xun Yang <x299yang@uwaterloo.ca>
Date: Tue, 3 Mar 2020 19:45:16 -0500
Subject: [PATCH] some type checking stuff

---
 AST.py              |  8 ++++-
 AstBuilding.py      |  3 ++
 ExprPrimaryNodes.py | 14 +++++++--
 MemberNodes.py      |  2 +-
 TheTypeNode.py      | 75 ++++++++++++++++++++++++++++++---------------
 5 files changed, 74 insertions(+), 28 deletions(-)

diff --git a/AST.py b/AST.py
index 4a219c7..f8ee30d 100644
--- a/AST.py
+++ b/AST.py
@@ -12,6 +12,7 @@ class ASTNode():
     def __init__(self, parseTree):
         self.parseTree = parseTree
         self.children = []
+        self.myType = "" # either empty string or a TypeStruct
 
     # Do certains actions on every node of the AST tree
     #   call the same method in each class and its children recursively
@@ -40,7 +41,7 @@ class ASTNode():
                 else:
                     c.recurseBuildEnv(result)
                 if c.__class__.__name__ == 'VarDclNode':
-                    preVarDcl = c     
+                    preVarDcl = c
 
 
     def buildEnv(self, parentEnv):
@@ -53,6 +54,11 @@ class ASTNode():
     def checkHierarchy(self):
         pass
 
+    def checkType(self):
+        for c in self.children:
+            if c and hasattr(c, 'checkType'):
+                c.checkType()
+
     def printNodePretty(self, prefix=0):
         pp = pprint.PrettyPrinter(indent=prefix)
         pp.pprint(self.__class__.__name__)
diff --git a/AstBuilding.py b/AstBuilding.py
index 27803a9..fc92c15 100644
--- a/AstBuilding.py
+++ b/AstBuilding.py
@@ -42,4 +42,7 @@ def buildEnvAndLink(ASTs):
     for t in ASTs:
         t[1].recurseAction("checkHierarchy")
 
+    for t in ASTs:
+        t[1].checkType()
+
 #######################################################
diff --git a/ExprPrimaryNodes.py b/ExprPrimaryNodes.py
index 6ac4f83..9f62d59 100644
--- a/ExprPrimaryNodes.py
+++ b/ExprPrimaryNodes.py
@@ -1,7 +1,7 @@
 from AST import ASTNode, getParseTreeNodes
 from Environment import Env
 from UnitNodes import LiteralNode
-from TheTypeNode import TypeNode
+from TheTypeNode import TypeNode, TypeStruct
 
 # file containing smaller (lower level nodes) in the AST
 # nodes in this file:
@@ -48,7 +48,7 @@ def makeNodeFromAllPrimary(parseTree):
             if parseTree.children[0].children[0].name == 'arrayAccess':
                 return ArrayAccessNode(parseTree.children[0].children[0])
             parseTree = parseTree.children[0].children[0]
-    
+
     if parseTree.name == 'primary':
         if parseTree.children[0].name == 'arrayAccess':
             return ArrayAccessNode(parseTree.children[0])
@@ -207,6 +207,16 @@ class ExprNode(ASTNode):
         self.children.append(self.left)
         self.children.append(self.right)
 
+    def checkType(self):
+        super().checkType() # check children's type first to populate their myType field
+        if self.op == '==' or self.op == '!=':
+            if (self.left.myType == self.right.myType) or (self.left.myType.isNum() and self.right.myType.isNum()):
+                self.myType = TypeStruct("boolean")
+            else:
+                raise Exception('ERROR: Incompatible types for comparison.')
+        # TODO: type check other types of expr
+
+
 
 ###################################################################################
 # fieldAccess primary PERIOD ID
diff --git a/MemberNodes.py b/MemberNodes.py
index 0b07d9e..23c65a8 100644
--- a/MemberNodes.py
+++ b/MemberNodes.py
@@ -61,7 +61,7 @@ class MethodNode(ASTNode):
         for n in nameNodes:
             paramNode = ParamNode(n)
             self.params.append(paramNode)
-            self.paramTypes += paramNode.paramType.name
+            self.paramTypes += paramNode.paramType.myType.name
 
         nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
         for n in nameNodes:
diff --git a/TheTypeNode.py b/TheTypeNode.py
index 0752c34..0ff37f1 100644
--- a/TheTypeNode.py
+++ b/TheTypeNode.py
@@ -1,53 +1,80 @@
 from AST import ASTNode, getParseTreeNodes
 ##################################################################################
-# type: primitiveType, ArrayType, RefType
+# TypeNode: an AST node represents a type
+# TypeStruct: a struct holding type information for type checking
+
 class TypeNode(ASTNode):
     # always list all fields in the init method to show the class structure
     def __init__(self, parseTree):
         self.parseTree = parseTree
         self.name = ''
-        self.isArray = False
-        self.isPrimitive = False
         self.env = None
         self.children = []
-        self.myType = None # pointer pointing to the type
+        self.myType = "" # empty string or typeStruct
 
         if parseTree == 'VOID':
-            self.name = 'void'
-            self.isPrimitive = True
+            self.myType = TypeStruct('void')
         else:
             nameNodes = getParseTreeNodes(['BOOLEAN', 'BYTE', 'CHAR', 'INT', 'SHORT'], parseTree)
             if nameNodes:
-                self.isPrimitive = True
-                self.name = nameNodes[0].lex
+                self.myType = TypeStruct(nameNodes[0].lex)
             else:
-                self.name = getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex
+                self.myType = TypeStruct(getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex)
 
             nameNodes = getParseTreeNodes(['LSQRBRACK'], parseTree)
             if nameNodes:
-                self.isArray = True
+                self.myType.isArray = True
 
     def __eq__(self, other):
-        return self.name == other.name
+        return self.myType == other.myType
 
     def linkType(self):
+        self.myType.link(self.env)
+
+
+class TypeStruct():
+    def __init__(self, name):
+        self.isArray = False
+        self.isPrimitive = False
+        self.typePointer = None
+        self.name = name
+        if name in ['boolean', 'byte', 'char', 'int', 'short', 'void']:
+            self.isPrimitive = True
+
+    def link(self, env):
         if not self.isPrimitive:
-            self.myType = self.env.getNode(self.name, 'type')
-            self.name = self.myType.canonName  # Use canonName instead of simple name for comparison
-        else:
-            self.myType = self.name
+            self.typePointer = env.getNode(self.name, 'type')
+            self.name = self.typePointer.canonName  # Use canonName instead of simple name for comparison
+
+    def __eq__(self, other):
+        return self.name == other.name
 
     def isNum(self):
         return self.name in ['int', 'short', 'char', 'byte']
 
-    # if self is assignable to input typeNode: left := self
-    def assignable(self, left):
-        if self == left \
-        or (self.name in ['short', 'char', 'byte'] and left.name == 'int') \
-        or (self.name == 'byte' and left.name == 'short') \
-        or (not left.isPrimitive and self.name == 'null'):
-            return True
-        return False
+    # if self is assignable to input typeNode: self := right
+    # right is either a TypeNode or a LiteralNode
+    def assignable(self, right):
+        if self.isArray == right.isArray:
+            if self == right \
+            or (right.name in ['short', 'char', 'byte'] and self.name == 'int') \
+            or (right.name == 'byte' and self.name == 'short') \
+            or (not self.isPrimitive and right.name == 'null'):
+                return True
+            #  check if self is super of right
+            elif ((not self.isPrimitive) and (not right.isPrimitive)) \
+            and (self.name in getSupers(right.typePointer)):
+                return True
+            return False
 
+        return False
 
-    # is java.Object added to super class of everything/
+# helper: get list of all super class/interface of a ClassInterNode
+def getSupers(classType):
+    result = []
+    if not classType.super:
+        return result
+    for s in classType.super:
+        result.append(s.canonName)
+        result.extend(getSupers(s))
+    return result
-- 
GitLab