Skip to content
Snippets Groups Projects
Commit 973998c2 authored by Nishant Kumar's avatar Nishant Kumar
Browse files

Small bug fixes:

1. Bug fix in SeeDot/IRBuilderCSF.py -- for visitFloat() -- had a bug earlier.
2. TFNodesAST -- compilation of mul TensorFlow node
3. Added some more functions in library files of TFEzPCLibrary.
parent be909a64
No related branches found
No related tags found
No related merge requests found
...@@ -95,15 +95,15 @@ class IRBuilderCSF(ASTVisitor): ...@@ -95,15 +95,15 @@ class IRBuilderCSF(ASTVisitor):
r = node.value r = node.value
p = self.get_expnt(abs(r)) p = self.get_expnt(abs(r))
k = IR.DataType.getInt(np.ldexp(r, p)) k = IR.DataType.getInt(np.ldexp(r, p))
comment = IR.Comment('Float to int : ' + str(r) + ' to ' + str(k))
expr = None expr = None
prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(r), str(k), node.isSecret))])
if not(node.isSecret): if not(node.isSecret):
expr = IR.Int(k) expr = IR.Int(k)
else: else:
expr = self.getTempVar() expr = self.getTempVar()
self.decls[expr.idf] = [node.type] self.decls[expr.idf] = [node.type]
prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog) prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]), prog)
return (IR.Prog([comment]), expr) return (prog, expr)
def visitId(self, node:AST.ID, args=None): def visitId(self, node:AST.ID, args=None):
idf = node.name idf = node.name
......
...@@ -121,7 +121,7 @@ class TFNodesAST: ...@@ -121,7 +121,7 @@ class TFNodesAST:
inputsRef = curNode.getInputsRef() inputsRef = curNode.getInputsRef()
assert(len(inputsRef) == 2) assert(len(inputsRef) == 2)
return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
TFNodesAST.getOperatorsIdx('*'), TFNodesAST.getOperatorsIdx('.*'),
AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
)) ))
......
...@@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in ...@@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in
}; };
} }
(**************************)
def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int32_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int32_al[s1][s2] outArr){
for i=[0:s1]{
for j=[0:s2]{
outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j];
};
};
}
(**************************) (**************************)
def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int32_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int32_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int32_al[s1][s2][s3][s4] outp){
for i1=[0:s1]{ for i1=[0:s1]{
...@@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p ...@@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p
}; };
} }
def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int32_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_al[inp2s1][inp2s2] inp2, int32_pl axis, int32_al[s1][s2] outp){
for i1=[0:s1]{
for i2=[0:s2]{
if (axis==0){
if (i1 < inp1s1){
outp[i1][i2] = inp1[i1][i2];
}
else{
outp[i1][i2] = inp2[i1-inp1s1][i2];
};
}
else{
if (i2 < inp1s2){
outp[i1][i2] = inp1[i1][i2];
}
else{
outp[i1][i2] = inp2[i1][i2-inp1s2];
};
};
};
};
}
(**************************) (**************************)
(* Generic implementation of Conv2DCSF *) (* Generic implementation of Conv2DCSF *)
...@@ -353,3 +385,6 @@ def void ClearMemPublic(int32_pl x){ ...@@ -353,3 +385,6 @@ def void ClearMemPublic(int32_pl x){
return; return;
} }
def void ClearMemPublic1(int32_pl s, int32_pl[s] x){
return;
}
\ No newline at end of file
...@@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in ...@@ -146,6 +146,15 @@ def void CreateIdentity44(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, in
}; };
} }
(**************************)
def void CreateCopy2211(int32_pl s1, int32_pl s2, int32_pl inps1, int32_pl inps2, int64_al[inps1][inps2] inArr, int32_pl perDimSize, int32_pl[perDimSize] beginIdx, int32_pl[perDimSize] sizeIdx, int64_al[s1][s2] outArr){
for i=[0:s1]{
for j=[0:s2]{
outArr[i][j] = inArr[beginIdx[0]+i][beginIdx[1]+j];
};
};
}
(**************************) (**************************)
def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int64_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int64_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int64_al[s1][s2][s3][s4] outp){ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl inp1s1, int32_pl inp1s2, int32_pl inp1s3, int32_pl inp1s4, int64_al[inp1s1][inp1s2][inp1s3][inp1s4] inp1, int32_pl inp2s1, int32_pl inp2s2, int32_pl inp2s3, int32_pl inp2s4, int64_al[inp2s1][inp2s2][inp2s3][inp2s4] inp2, int32_pl axis, int64_al[s1][s2][s3][s4] outp){
for i1=[0:s1]{ for i1=[0:s1]{
...@@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p ...@@ -195,6 +204,29 @@ def void Concat2T444(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_p
}; };
} }
def void Concat2T222(int32_pl s1, int32_pl s2, int32_pl inp1s1, int32_pl inp1s2, int64_al[inp1s1][inp1s2] inp1, int32_pl inp2s1, int32_pl inp2s2, int64_al[inp2s1][inp2s2] inp2, int32_pl axis, int64_al[s1][s2] outp){
for i1=[0:s1]{
for i2=[0:s2]{
if (axis==0){
if (i1 < inp1s1){
outp[i1][i2] = inp1[i1][i2];
}
else{
outp[i1][i2] = inp2[i1-inp1s1][i2];
};
}
else{
if (i2 < inp1s2){
outp[i1][i2] = inp1[i1][i2];
}
else{
outp[i1][i2] = inp2[i1][i2-inp1s2];
};
};
};
};
}
(**************************) (**************************)
(* Generic implementation of Conv2DCSF *) (* Generic implementation of Conv2DCSF *)
...@@ -351,4 +383,8 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32 ...@@ -351,4 +383,8 @@ def void Squeeze24(int32_pl s1, int32_pl s2, int32_pl dim1, int32_pl dim2, int32
(**************************) (**************************)
def void ClearMemPublic(int32_pl x){ def void ClearMemPublic(int32_pl x){
return; return;
}
def void ClearMemPublic1(int32_pl s, int32_pl[s] x){
return;
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment