From 973998c27caf90ce20214b2a65aa21196c8fb6ba Mon Sep 17 00:00:00 2001 From: Nishant Kumar <t-niskum@microsoft.com> Date: Fri, 4 Oct 2019 18:44:17 +0530 Subject: [PATCH] Small bug fixes: 1. Bug fix in SeeDot/IRBuilderCSF.py -- for visitFloat() -- had a bug earlier. 2. TFNodesAST -- compilation of mul TensorFlow node 3. Added some more functions in library files of TFEzPCLibrary. --- Athos/SeeDot/IR/IRBuilderCSF.py | 6 ++-- Athos/TFCompiler/TFNodesAST.py | 2 +- Athos/TFEzPCLibrary/Library32_common.ezpc | 35 ++++++++++++++++++++++ Athos/TFEzPCLibrary/Library64_common.ezpc | 36 +++++++++++++++++++++++ 4 files changed, 75 insertions(+), 4 deletions(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 31cd90a..bc566e8 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -95,15 +95,15 @@ class IRBuilderCSF(ASTVisitor): r = node.value p = self.get_expnt(abs(r)) k = IR.DataType.getInt(np.ldexp(r, p)) - comment = IR.Comment('Float to int : ' + str(r) + ' to ' + str(k)) expr = None + prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(r), str(k), node.isSecret))]) if not(node.isSecret): expr = IR.Int(k) else: expr = self.getTempVar() self.decls[expr.idf] = [node.type] - prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog) - return (IR.Prog([comment]), expr) + prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog) + return (prog, expr) def visitId(self, node:AST.ID, args=None): idf = node.name diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index ad3c64f..d5d0956 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -121,7 +121,7 @@ class TFNodesAST: inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('*'), + TFNodesAST.getOperatorsIdx('.*'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) )) diff --git a/Athos/TFEzPCLibrary/Library32_common.ezpc b/Athos/TFEzPCLibrary/Library32_common.ezpc index 935ca25..f6184db 100644 --- a/Athos/TFEzPCLibrary/Library32_common.ezpc +++ b/Athos/TFEzPCLibrary/Library32_common.ezpc @@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in }; } +(**************************) +def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int32_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2] outArr){ + for i=[0:s1]{ + for j=[0:s2]{ + outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j]; + }; + }; +} + (**************************) def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){ for i1=[0:s1]{ @@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p }; } +def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int32_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_al[inp2s1][inp2s2] inp2, int32_pl axis, int32_al[s1][s2] outp){ + for i1=[0:s1]{ + for i2=[0:s2]{ + if (axis==0){ + if (i1 < inp1s1){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1-inp1s1][i2]; + }; + } + else{ + if (i2 < inp1s2){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1][i2-inp1s2]; + }; + }; + }; + }; +} + (**************************) (* Generic implementation of Conv2DCSF *) @@ -353,3 +385,6 @@ def void ClearMemPublic(int32_pl x){ return; } +def void ClearMemPublic1(int32_pl s, int32_pl[s] x){ + return; +} \ No newline at end of file diff --git a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc index a715aba..042cb9c 100644 --- a/Athos/TFEzPCLibrary/Library64_common.ezpc +++ b/Athos/TFEzPCLibrary/Library64_common.ezpc @@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in }; } +(**************************) +def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int64_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int64_al[s1][s2] outArr){ + for i=[0:s1]{ + for j=[0:s2]{ + outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j]; + }; + }; +} + (**************************) def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int64_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int64_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int64_al[s1][s2][s3][s4] outp){ for i1=[0:s1]{ @@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p }; } +def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int64_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int64_al[inp2s1][inp2s2] inp2, int32_pl axis, int64_al[s1][s2] outp){ + for i1=[0:s1]{ + for i2=[0:s2]{ + if (axis==0){ + if (i1 < inp1s1){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1-inp1s1][i2]; + }; + } + else{ + if (i2 < inp1s2){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1][i2-inp1s2]; + }; + }; + }; + }; +} + (**************************) (* Generic implementation of Conv2DCSF *) @@ -351,4 +383,8 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32 (**************************) def void ClearMemPublic(int32_pl x){ return; +} + +def void ClearMemPublic1(int32_pl s, int32_pl[s] x){ + return; } \ No newline at end of file -- GitLab