From b15c3227f5a1d589e2b685da073abdd3af0db2a0 Mon Sep 17 00:00:00 2001
From: Bhatu <prbhatu@microsoft.com>
Date: Wed, 25 Nov 2020 17:07:59 +0530
Subject: [PATCH] Fix double scaledown bug for mul like ops.

Squarediff exposed a bug in codegen where both inputs to mul were
same. Depending on the scale of the variable at that point, we
sometimes do a scaledown of the inputs of multiplication so as to
maintain precision.
    scaledown(a, scale)
    scaledown(b, scale)
    mul(a,b)
But in this case both the inputs to mul were same so we were doing
    scaledown(a, scale)
    scaledown(a, scale)
    mul(a,a)
This led to loss of precision. Now we just do:
    scaledown(a, scale)
    mul(a,a)
---
 Athos/SeeDot/IR/IRBuilderCSF.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index 71db944..9c7d69f 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -493,12 +493,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -533,12 +534,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "MulInt", expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MulInt", expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MulInt", expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -578,12 +580,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "ScalarMul", expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ScalarMul", expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ScalarMul", expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -639,12 +642,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "MatMul2D", expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MatMul2D", expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MatMul2D", expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -728,12 +732,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "Conv", returnExpr, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr1.idf]
 			expr2_sf = self.scaleFacMapping[expr2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac
@@ -833,12 +838,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "ConvTranspose", returnExpr, self.scaleFac)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr1.idf]
 			expr2_sf = self.scaleFacMapping[expr2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr2.idf] = self.scaleFac
 			self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac
-- 
GitLab