diff --git a/AST.py b/AST.py
index 4a219c7b21f8e86e79b071927d2ccf1d55925193..f8ee30d555676c9cb9336bb0aedbb66605e2ae77 100644
--- a/AST.py
+++ b/AST.py
@@ -12,6 +12,7 @@ class ASTNode():
     def __init__(self, parseTree):
         self.parseTree = parseTree
         self.children = []
+        self.myType = "" # either empty string or a TypeStruct
 
     # Do certains actions on every node of the AST tree
     #   call the same method in each class and its children recursively
@@ -40,7 +41,7 @@ class ASTNode():
                 else:
                     c.recurseBuildEnv(result)
                 if c.__class__.__name__ == 'VarDclNode':
-                    preVarDcl = c     
+                    preVarDcl = c
 
 
     def buildEnv(self, parentEnv):
@@ -53,6 +54,11 @@ class ASTNode():
     def checkHierarchy(self):
         pass
 
+    def checkType(self):
+        for c in self.children:
+            if c and hasattr(c, 'checkType'):
+                c.checkType()
+
     def printNodePretty(self, prefix=0):
         pp = pprint.PrettyPrinter(indent=prefix)
         pp.pprint(self.__class__.__name__)
diff --git a/AstBuilding.py b/AstBuilding.py
index 27803a9225810aa6218ff121358b22fa38c6dcce..fc92c15e92f7be5332c94f5dfea4f9227950685f 100644
--- a/AstBuilding.py
+++ b/AstBuilding.py
@@ -42,4 +42,7 @@ def buildEnvAndLink(ASTs):
     for t in ASTs:
         t[1].recurseAction("checkHierarchy")
 
+    for t in ASTs:
+        t[1].checkType()
+
 #######################################################
diff --git a/ExprPrimaryNodes.py b/ExprPrimaryNodes.py
index 6ac4f8372528f858c0cfdbb4327cba537b3acf61..9f62d5954e8843e039d20ff8405fbdba03a74f09 100644
--- a/ExprPrimaryNodes.py
+++ b/ExprPrimaryNodes.py
@@ -1,7 +1,7 @@
 from AST import ASTNode, getParseTreeNodes
 from Environment import Env
 from UnitNodes import LiteralNode
-from TheTypeNode import TypeNode
+from TheTypeNode import TypeNode, TypeStruct
 
 # file containing smaller (lower level nodes) in the AST
 # nodes in this file:
@@ -48,7 +48,7 @@ def makeNodeFromAllPrimary(parseTree):
             if parseTree.children[0].children[0].name == 'arrayAccess':
                 return ArrayAccessNode(parseTree.children[0].children[0])
             parseTree = parseTree.children[0].children[0]
-    
+
     if parseTree.name == 'primary':
         if parseTree.children[0].name == 'arrayAccess':
             return ArrayAccessNode(parseTree.children[0])
@@ -207,6 +207,16 @@ class ExprNode(ASTNode):
         self.children.append(self.left)
         self.children.append(self.right)
 
+    def checkType(self):
+        super().checkType() # check children's type first to populate their myType field
+        if self.op == '==' or self.op == '!=':
+            if (self.left.myType == self.right.myType) or (self.left.myType.isNum() and self.right.myType.isNum()):
+                self.myType = TypeStruct("boolean")
+            else:
+                raise Exception('ERROR: Incompatible types for comparison.')
+        # TODO: type check other types of expr
+
+
 
 ###################################################################################
 # fieldAccess primary PERIOD ID
diff --git a/MemberNodes.py b/MemberNodes.py
index 0b07d9eb7bcaca246e678c79ee425ed2009dd41c..23c65a8dec8c4970aa71dde8a39863811a2fea24 100644
--- a/MemberNodes.py
+++ b/MemberNodes.py
@@ -61,7 +61,7 @@ class MethodNode(ASTNode):
         for n in nameNodes:
             paramNode = ParamNode(n)
             self.params.append(paramNode)
-            self.paramTypes += paramNode.paramType.name
+            self.paramTypes += paramNode.paramType.myType.name
 
         nameNodes = getParseTreeNodes(['type', 'VOID'], parseTree, ['methodBody', 'params'])
         for n in nameNodes:
diff --git a/TheTypeNode.py b/TheTypeNode.py
index 0752c341b632cbad627958594c3b068ee9f2f51e..0ff37f19bf76bc0c2f8ecf1ab42e0fdf6d497e15 100644
--- a/TheTypeNode.py
+++ b/TheTypeNode.py
@@ -1,53 +1,80 @@
 from AST import ASTNode, getParseTreeNodes
 ##################################################################################
-# type: primitiveType, ArrayType, RefType
+# TypeNode: an AST node represents a type
+# TypeStruct: a struct holding type information for type checking
+
 class TypeNode(ASTNode):
     # always list all fields in the init method to show the class structure
     def __init__(self, parseTree):
         self.parseTree = parseTree
         self.name = ''
-        self.isArray = False
-        self.isPrimitive = False
         self.env = None
         self.children = []
-        self.myType = None # pointer pointing to the type
+        self.myType = "" # empty string or typeStruct
 
         if parseTree == 'VOID':
-            self.name = 'void'
-            self.isPrimitive = True
+            self.myType = TypeStruct('void')
         else:
             nameNodes = getParseTreeNodes(['BOOLEAN', 'BYTE', 'CHAR', 'INT', 'SHORT'], parseTree)
             if nameNodes:
-                self.isPrimitive = True
-                self.name = nameNodes[0].lex
+                self.myType = TypeStruct(nameNodes[0].lex)
             else:
-                self.name = getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex
+                self.myType = TypeStruct(getParseTreeNodes(['ID', 'COMPID'], parseTree)[0].lex)
 
             nameNodes = getParseTreeNodes(['LSQRBRACK'], parseTree)
             if nameNodes:
-                self.isArray = True
+                self.myType.isArray = True
 
     def __eq__(self, other):
-        return self.name == other.name
+        return self.myType == other.myType
 
     def linkType(self):
+        self.myType.link(self.env)
+
+
+class TypeStruct():
+    def __init__(self, name):
+        self.isArray = False
+        self.isPrimitive = False
+        self.typePointer = None
+        self.name = name
+        if name in ['boolean', 'byte', 'char', 'int', 'short', 'void']:
+            self.isPrimitive = True
+
+    def link(self, env):
         if not self.isPrimitive:
-            self.myType = self.env.getNode(self.name, 'type')
-            self.name = self.myType.canonName  # Use canonName instead of simple name for comparison
-        else:
-            self.myType = self.name
+            self.typePointer = env.getNode(self.name, 'type')
+            self.name = self.typePointer.canonName  # Use canonName instead of simple name for comparison
+
+    def __eq__(self, other):
+        return self.name == other.name
 
     def isNum(self):
         return self.name in ['int', 'short', 'char', 'byte']
 
-    # if self is assignable to input typeNode: left := self
-    def assignable(self, left):
-        if self == left \
-        or (self.name in ['short', 'char', 'byte'] and left.name == 'int') \
-        or (self.name == 'byte' and left.name == 'short') \
-        or (not left.isPrimitive and self.name == 'null'):
-            return True
-        return False
+    # if self is assignable to input typeNode: self := right
+    # right is either a TypeNode or a LiteralNode
+    def assignable(self, right):
+        if self.isArray == right.isArray:
+            if self == right \
+            or (right.name in ['short', 'char', 'byte'] and self.name == 'int') \
+            or (right.name == 'byte' and self.name == 'short') \
+            or (not self.isPrimitive and right.name == 'null'):
+                return True
+            #  check if self is super of right
+            elif ((not self.isPrimitive) and (not right.isPrimitive)) \
+            and (self.name in getSupers(right.typePointer)):
+                return True
+            return False
 
+        return False
 
-    # is java.Object added to super class of everything/
+# helper: get list of all super class/interface of a ClassInterNode
+def getSupers(classType):
+    result = []
+    if not classType.super:
+        return result
+    for s in classType.super:
+        result.append(s.canonName)
+        result.extend(getSupers(s))
+    return result