diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 31cd90a3d1e1dd783f0ff3007645813d7a5e77d0..bc566e8e1d4537a1e7e79b8c653035e3eb3dd4fd 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -95,15 +95,15 @@ class IRBuilderCSF(ASTVisitor): r = node.value p = self.get_expnt(abs(r)) k = IR.DataType.getInt(np.ldexp(r, p)) - comment = IR.Comment('Float to int : ' + str(r) + ' to ' + str(k)) expr = None + prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(r), str(k), node.isSecret))]) if not(node.isSecret): expr = IR.Int(k) else: expr = self.getTempVar() self.decls[expr.idf] = [node.type] - prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog) - return (IR.Prog([comment]), expr) + prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog) + return (prog, expr) def visitId(self, node:AST.ID, args=None): idf = node.name diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index ad3c64fe2cb5c9d83ab437c2282996f92ed5a125..d5d09562831aa54bf835e49cf6d60834b9d9198a 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -121,7 +121,7 @@ class TFNodesAST: inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('*'), + TFNodesAST.getOperatorsIdx('.*'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) )) diff --git a/Athos/TFEzPCLibrary/Library32_common.ezpc b/Athos/TFEzPCLibrary/Library32_common.ezpc index 935ca25c3fb8cd92c7543eb28c05d0b532d2e474..f6184dbf38302d9d631a120d1ae8a5a8c6725621 100644 --- a/Athos/TFEzPCLibrary/Library32_common.ezpc +++ b/Athos/TFEzPCLibrary/Library32_common.ezpc @@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in }; } +(**************************) +def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int32_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2] outArr){ + for i=[0:s1]{ + for j=[0:s2]{ + outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j]; + }; + }; +} + (**************************) def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){ for i1=[0:s1]{ @@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p }; } +def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int32_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_al[inp2s1][inp2s2] inp2, int32_pl axis, int32_al[s1][s2] outp){ + for i1=[0:s1]{ + for i2=[0:s2]{ + if (axis==0){ + if (i1 < inp1s1){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1-inp1s1][i2]; + }; + } + else{ + if (i2 < inp1s2){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1][i2-inp1s2]; + }; + }; + }; + }; +} + (**************************) (* Generic implementation of Conv2DCSF *) @@ -353,3 +385,6 @@ def void ClearMemPublic(int32_pl x){ return; } +def void ClearMemPublic1(int32_pl s, int32_pl[s] x){ + return; +} \ No newline at end of file diff --git a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc index a715aba33103934d8e6521badccf5b48ce840ce9..042cb9ca0eb42e8ceee1e81420a696aa65d940c4 100644 --- a/Athos/TFEzPCLibrary/Library64_common.ezpc +++ b/Athos/TFEzPCLibrary/Library64_common.ezpc @@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in }; } +(**************************) +def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int64_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int64_al[s1][s2] outArr){ + for i=[0:s1]{ + for j=[0:s2]{ + outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j]; + }; + }; +} + (**************************) def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int64_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int64_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int64_al[s1][s2][s3][s4] outp){ for i1=[0:s1]{ @@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p }; } +def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int64_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int64_al[inp2s1][inp2s2] inp2, int32_pl axis, int64_al[s1][s2] outp){ + for i1=[0:s1]{ + for i2=[0:s2]{ + if (axis==0){ + if (i1 < inp1s1){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1-inp1s1][i2]; + }; + } + else{ + if (i2 < inp1s2){ + outp[i1][i2] = inp1[i1][i2]; + } + else{ + outp[i1][i2] = inp2[i1][i2-inp1s2]; + }; + }; + }; + }; +} + (**************************) (* Generic implementation of Conv2DCSF *) @@ -351,4 +383,8 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32 (**************************) def void ClearMemPublic(int32_pl x){ return; +} + +def void ClearMemPublic1(int32_pl s, int32_pl[s] x){ + return; } \ No newline at end of file