From 614fe2ea8f2b7bb4129aab40b7fbdf9900596cd8 Mon Sep 17 00:00:00 2001 From: Pratik Bhatu <prbhatu@microsoft.com> Date: Thu, 4 Jun 2020 20:38:42 +0530 Subject: [PATCH] Move loop versions of convolution to cpp library --- Athos/TFEzPCLibrary/Library32_common.ezpc | 60 +--------------------- Athos/TFEzPCLibrary/Library32_cpp.ezpc | 59 ++++++++++++++++++++- Athos/TFEzPCLibrary/Library64_common.ezpc | 54 +------------------ Athos/TFEzPCLibrary/Library64_cpp.ezpc | 58 ++++++++++++++++++++- Athos/TFEzPCLibrary/Library64_porthos.ezpc | 2 +- 5 files changed, 116 insertions(+), 117 deletions(-) diff --git a/Athos/TFEzPCLibrary/Library32_common.ezpc b/Athos/TFEzPCLibrary/Library32_common.ezpc index 1832c82..e6b1ccd 100644 --- a/Athos/TFEzPCLibrary/Library32_common.ezpc +++ b/Athos/TFEzPCLibrary/Library32_common.ezpc @@ -455,26 +455,6 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, Conv2DReshapeMatMulOP(N, newH, newW, CO, matmulOP, outArr); } -(* int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI][CO] filterArr, - int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) - -def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, int32_pl G, - int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI][CO] filterArr, - int32_pl consSF, - int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); -} - (**************************) (* Generic implementation of Conv2D with Groups *) @@ -669,27 +649,6 @@ def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int3 }; } -(* int32_al[N][D][H][W][CI] inputArr, - int32_al[FD][FH][FW][CI][CO] filterArr, - int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) -(* Loop implementation of convolution run faster with multithreadin *) -def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int32_al[N][D][H][W][CI] inputArr, - int32_al[FD][FH][FW][CI][CO] filterArr, - int32_pl consSF, - int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); -} - (* int32_al[N][D][H][W][CI] inputArr, int32_al[FD][FH][FW][CI][CO] filterArr, int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr @@ -1016,23 +975,6 @@ def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrim }; } -(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int32_al[FD][FH][FW][CO][CI] filter, - int32_al[N][D][H][W][CO] outputArr -*) -def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl D, int32_pl H, int32_pl W, - int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int32_al[FD][FH][FW][CO][CI] filterArr, - int32_pl consSF, - int32_al[N][D][H][W][CO] outArr) -{ - ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); -} - (* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, int32_al[FD][FH][FW][CO][CI] filter, int32_al[N][D][H][W][CO] outputArr @@ -1081,4 +1023,4 @@ def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr) { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library32_cpp.ezpc b/Athos/TFEzPCLibrary/Library32_cpp.ezpc index 57dd0a8..fb65727 100644 --- a/Athos/TFEzPCLibrary/Library32_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library32_cpp.ezpc @@ -84,6 +84,26 @@ def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, }; } +(* int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) + +def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); +} + (**************************) def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, @@ -126,6 +146,26 @@ def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, }; } +(* int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +(* Loop implementation of convolution run faster with multithreadin *) +def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); +} (**************************) def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, @@ -171,7 +211,22 @@ def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int }; }; } - +(* int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filter, + int32_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int32_al[N][D][H][W][CO] outArr) +{ + ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); +} (**************************) def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int32_al[inArrS1][inArrS2] inArr, int32_pl dim, int32_al[outArrS1] outArr){ @@ -594,4 +649,4 @@ def void StartComputation() def void EndComputation() { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc index 9eda438..eb88b55 100644 --- a/Athos/TFEzPCLibrary/Library64_common.ezpc +++ b/Athos/TFEzPCLibrary/Library64_common.ezpc @@ -460,20 +460,6 @@ def void Conv2DCSF(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr *) -def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, int32_pl G, - int64_al[N][H][W][CI] inputArr, - int64_al[FH][FW][CI][CO] filterArr, - int32_pl consSF, - int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); -} (**************************) (* Generic implementation of Conv2D with Groups *) @@ -669,26 +655,6 @@ def void Conv3DReshapeInput(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int3 }; } -(* int64_al[N][D][H][W][CI] inputArr, - int64_al[FD][FH][FW][CI][CO] filterArr, - int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr -*) -(* Loop implementation of convolution run faster with multithreadin *) -def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int64_al[N][D][H][W][CI] inputArr, - int64_al[FD][FH][FW][CI][CO] filterArr, - int32_pl consSF, - int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) -{ - int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; - int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; - int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; - - Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); -} (* int64_al[N][D][H][W][CI] inputArr, int64_al[FD][FH][FW][CI][CO] filterArr, @@ -1015,24 +981,6 @@ def void ConvTranspose3DReshapeInput(int32_pl N, int32_pl DPrime, int32_pl HPrim }; }; } - -(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int64_al[FD][FH][FW][CO][CI] filter, - int64_al[N][D][H][W][CO] outputArr -*) -def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl D, int32_pl H, int32_pl W, - int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int64_al[FD][FH][FW][CO][CI] filterArr, - int32_pl consSF, - int64_al[N][D][H][W][CO] outArr) -{ - ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); -} - (* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, int64_al[FD][FH][FW][CO][CI] filter, int64_al[N][D][H][W][CO] outputArr @@ -1081,4 +1029,4 @@ def void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int def void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr) { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_cpp.ezpc b/Athos/TFEzPCLibrary/Library64_cpp.ezpc index 3c2a6e0..c9afe6f 100644 --- a/Athos/TFEzPCLibrary/Library64_cpp.ezpc +++ b/Athos/TFEzPCLibrary/Library64_cpp.ezpc @@ -83,6 +83,25 @@ def void Conv2DLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, }; }; } +(* int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int64_al[N][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) + +def void Conv2DCSFLoop(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int32_pl consSF, + int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv2DLoop(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, G, inputArr, filterArr, consSF, outArr); +} (**************************) def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, @@ -126,6 +145,26 @@ def void Conv3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, }; } +(* int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int64_al[N][((D-FD+zPadDLeft+zPadDRight)/strideD)+1][((H-FH+zPadHLeft+zPadHRight)/strideH)+1][((W-FW+zPadWLeft+zPadWRight)/strideW)+1][CO] outArr +*) +(* Loop implementation of convolution run faster with multithreadin *) +def void Conv3DCSFLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int32_pl consSF, + int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr) +{ + int32_pl outD = ((D-FD+(zPadDLeft+zPadDRight))/strideD)+1; + int32_pl outH = ((H-FH+(zPadHLeft+zPadHRight))/strideH)+1; + int32_pl outW = ((W-FW+(zPadWLeft+zPadWRight))/strideW)+1; + + Conv3DLoop(N, D, H, W, CI, FD, FH, FW, CO, zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideD, strideH, strideW, outD, outH, outW, inputArr, filterArr, consSF, outArr); +} (**************************) def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, @@ -171,7 +210,22 @@ def void ConvTranspose3DLoop(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int }; }; } - +(* int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filter, + int64_al[N][D][H][W][CO] outputArr +*) +def void ConvTranspose3DCSFLoop(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filterArr, + int32_pl consSF, + int64_al[N][D][H][W][CO] outArr) +{ + ConvTranspose3DLoop(N, DPrime, HPrime, WPrime, CI, FD, FH, FW, CO, zPadTrDLeft, zPadTrDRight, zPadTrHLeft, zPadTrHRight, zPadTrWLeft, zPadTrWRight, strideD, strideH, strideW, D, H, W, inputArr, filterArr, consSF, outArr); +} (**************************) def void ArgMax1(int32_pl outArrS1, int32_pl inArrS1, int32_pl inArrS2, int64_al[inArrS1][inArrS2] inArr, int32_pl dim, int64_al[outArrS1] outArr){ @@ -594,4 +648,4 @@ def void StartComputation() def void EndComputation() { return; -} \ No newline at end of file +} diff --git a/Athos/TFEzPCLibrary/Library64_porthos.ezpc b/Athos/TFEzPCLibrary/Library64_porthos.ezpc index cc1cf66..f092141 100644 --- a/Athos/TFEzPCLibrary/Library64_porthos.ezpc +++ b/Athos/TFEzPCLibrary/Library64_porthos.ezpc @@ -76,7 +76,7 @@ extern void ClearMemSecret1(int32_pl s1, int64_al[s1] arr); extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr); extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int64_al[s1][s2][s3] arr); extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr); -extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr) +extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr); extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr); -- GitLab