From 875b4cf14a9f5dd938f78c682b0c366a34c5b92b Mon Sep 17 00:00:00 2001
From: Xun Yang <x299yang@uwaterloo.ca>
Date: Sat, 11 Apr 2020 11:05:01 -0400
Subject: [PATCH] make SIT work for multi level inheritance

---
 AstBuilding.py |  1 +
 MemberNodes.py | 10 +++++++++-
 TypeNodes.py   | 35 +++++++++++++++++++----------------
 3 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/AstBuilding.py b/AstBuilding.py
index 11320e6..4e889e1 100644
--- a/AstBuilding.py
+++ b/AstBuilding.py
@@ -72,6 +72,7 @@ def codeGenPrep(ASTs):
     if interM: # prep SIT
         for i in range(len(interM)):
             interM[i].offset = i * 4
+            interM[i].isInterM = True
         for t in ASTs:
             classInterNode = t[1].typeDcl
             if classInterNode and classInterNode.__class__.__name__ == "ClassNode":
diff --git a/MemberNodes.py b/MemberNodes.py
index cd982e7..ece5d00 100644
--- a/MemberNodes.py
+++ b/MemberNodes.py
@@ -105,7 +105,8 @@ class MethodNode(ASTNode):
         self.typeName = typeName
         self.order = order
         self.myType = None
-        # self.SIToffset for methods that implements interface method
+        self.isInterM = False # if self is an interface method
+        self.replace = None # pointer to the method that this method overrides
 
         # get method name
         nameNodes = getParseTreeNodes(['ID'], parseTree, ['params', 'type', 'methodBody'])
@@ -301,6 +302,13 @@ class MethodNode(ASTNode):
 
         self.code += genProcedure(bodyCode, "Constructor definition for " + self.name + " " + self.paramTypes)
 
+    # gets the top-most level method that this method overrides
+    def getTopReplace(self):
+        if not self.replace:
+            return self
+        else:
+            return self.replace.getTopReplace()
+
 ############# helper for forward ref checking ########
 # Input: AST Node
 # Output: A list of names to be check
diff --git a/TypeNodes.py b/TypeNodes.py
index b4bd4bb..d7ae472 100644
--- a/TypeNodes.py
+++ b/TypeNodes.py
@@ -46,6 +46,7 @@ class ClassInterNode(ASTNode):
             unique.append(key)
 
         contains = self.getContains([])
+        self.contains = contains
 
         # 10. A class that contains any abstract methods must be abstract.
         if isinstance(self, ClassNode):
@@ -93,6 +94,8 @@ class ClassInterNode(ASTNode):
                         if 'abstract' not in scc.mods:
                             safeReplace(sic, scc, self.name)
                             sicOverwritten = True
+                            if sc.__class__.__name__ == "MethodNode":
+                                c.replace = sc
                             break
                 if not sicOverwritten:
                     superContains.append(sic)
@@ -124,6 +127,8 @@ class ClassInterNode(ASTNode):
                 if type(sc) == type(c) and sc == c:
                     safeReplace(sc, c, self.name)
                     scOverwritten = True
+                    if sc.__class__.__name__ == "MethodNode":
+                        c.replace = sc
                     break
             if not scOverwritten:
                 contains.append(sc)
@@ -368,9 +373,9 @@ class ClassNode(ClassInterNode):
                     methodDict[(i.name, i.paramTypes)] = i
 
         # Layout SIT
-        self.code += pLabel("SIT_" + self.name, "local")
+        self.data += pLabel("SIT_" + self.name, "local")
         for i in range(self.SITsize):
-            self.code += pLabel("SIT_" + i, "local")
+            self.data += pLabel("SIT_" + i, "local")
             self.data += p("dd", "42")
 
         self.data += ";END OF CLASS MEMORY LAYOUT FOR CLASS " + self.name + "\n"
@@ -402,20 +407,17 @@ class ClassNode(ClassInterNode):
         self.code += p("mov", "eax", "_SIT_spot" + self.name)
         self.code += p("mov", "[eax]", "dword _SIT_" + self.name)
 
-        interM = []
-        for s in self.super:
-            if s.__class__.__name__ == "InterNode":
-                interM += s.methods
-        for m in interM:
-            imple = methodDict[(m.name, m.paramTypes)]
-            className = imple.typeName
-
-            dlabel = "_SIT_" + str(m.offset / 4)
-            imLabel = "M_" + className + "_" + m.name + "_" + m.paramTypes
-            if className != self.typeName:
-                self.code += p("extern", imLabel)
-            self.code += p("mov", "eax", dlabel)
-            self.code += p("mov", "[eax]", "dword " + imLabel)
+        for m in self.contains:
+            if m.__class__.__name__ == "MethodNode":
+                sm = m.getTopReplace()
+                if sm != m and sm.isInterM:
+                    dlabel = "_SIT_" + str(sm.offset / 4)
+                    className = m.typeName
+                    imLabel = "M_" + className + "_" + m.name + "_" + m.paramTypes
+                    if className != self.typeName:
+                        self.code += p("extern", imLabel)
+                    self.code += p("mov", "eax", dlabel)
+                    self.code += p("mov", "[eax]", "dword " + imLabel)
         self.code += "; End of fill in SIT.\n"
 
         self.code += p(instruction="ret", arg1="")
@@ -425,6 +427,7 @@ class ClassNode(ClassInterNode):
 
 
 
+
         # print(self.name)
         # print(self.fieldOffset)
 
-- 
GitLab