Skip truncating the second operand when both operands of a multiplication-like
op are the same variable.

expr2_sf is read before the first operand is truncated and before its entry in
scaleFacMapping is lowered. When both operands share an idf, the second check
would therefore still see the old scale factor and emit a second truncate call
for the same variable. inputs_same guards against that.

Review fixes relative to the patch as originally written:
  * Conv / ConvTranspose hunks: inputs_same now compares expr1.idf and
    expr2.idf, the names every other line of those two hunks uses for the
    operands (the original referenced expr_1 / expr_2 there).
  * Conv hunk: the scale-factor update after truncating the second operand now
    keys on expr2.idf instead of expr_2.idf, for the same reason.

NOTE(review): leading whitespace inside the hunks was reconstructed (tabs
assumed) -- confirm against the target file before applying. The index line
still carries the blob ids of the original patch.

diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py
index 71db944034e712329caf8a77b1350a413d16a9d4..9c7d69f6cbff5d1d7c6cd961aa9513cba78dfa5c 100644
--- a/Athos/SeeDot/IR/IRBuilderCSF.py
+++ b/Athos/SeeDot/IR/IRBuilderCSF.py
@@ -493,12 +493,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -533,12 +534,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "MulInt", expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MulInt", expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MulInt", expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -578,12 +580,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "ScalarMul", expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ScalarMul", expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ScalarMul", expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -639,12 +642,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "MatMul2D", expr_3, Util.Config.consSF)
 		else:
+			inputs_same = (expr_1.idf == expr_2.idf)
 			expr1_sf = self.scaleFacMapping[expr_1.idf]
 			expr2_sf = self.scaleFacMapping[expr_2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MatMul2D", expr_1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr_1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MatMul2D", expr_2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr_2.idf] = self.scaleFac
 			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac
@@ -728,12 +732,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "Conv", returnExpr, Util.Config.consSF)
 		else:
+			inputs_same = (expr1.idf == expr2.idf)
 			expr1_sf = self.scaleFacMapping[expr1.idf]
 			expr2_sf = self.scaleFacMapping[expr2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr2, expr2_sf-self.scaleFac))
-				self.scaleFacMapping[expr_2.idf] = self.scaleFac
+				self.scaleFacMapping[expr2.idf] = self.scaleFac
 			self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac
@@ -833,12 +838,13 @@ class IRBuilderCSF(ASTVisitor):
 		if (Util.Config.disableTruncOpti):
 			progExtraAfter = self.addTruncateFunctionCall(node, "ConvTranspose", returnExpr, self.scaleFac)
 		else:
+			inputs_same = (expr1.idf == expr2.idf)
 			expr1_sf = self.scaleFacMapping[expr1.idf]
 			expr2_sf = self.scaleFacMapping[expr2.idf]
 			if (expr1_sf > self.scaleFac):
 				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr1, expr1_sf-self.scaleFac)
 				self.scaleFacMapping[expr1.idf] = self.scaleFac
-			if (expr2_sf > self.scaleFac):
+			if (not inputs_same) and (expr2_sf > self.scaleFac):
 				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr2, expr2_sf-self.scaleFac))
 				self.scaleFacMapping[expr2.idf] = self.scaleFac
 			self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac