diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 71d893ec645cd05f553c7601bfe2657e55c2ec1c..8ee57c84ead36739a1c53c0de9f61ac6946ef949 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -56,10 +56,7 @@ Config file should be a json in the following format: args = parser.parse_args() return args - -if __name__ == "__main__": - args = parse_args() - params = parse_config.get_params(args.config) +def generate_code(params): # Mandatory model_name = params["model_name"] input_tensor_info = params["input_tensors"] @@ -79,11 +76,12 @@ if __name__ == "__main__": "CPPRING", ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + cwd = os.getcwd() athos_dir = os.path.dirname(os.path.abspath(__file__)) model_abs_path = os.path.abspath(model_name) model_abs_dir = os.path.dirname(model_abs_path) # Generate graphdef and sizeInfo metadata - compile_tf.compile( + weights_path = compile_tf.compile( model_name, input_tensor_info, output_tensors, scale, save_weights ) @@ -91,7 +89,7 @@ if __name__ == "__main__": Athos.process_tf_graph(model_abs_path) # Compile to ezpc - model_base_name = model_name[:-3] + model_base_name = os.path.basename(model_abs_path)[:-3] ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( mname=model_base_name, bl=bitlength, target=target.lower() ) @@ -160,11 +158,12 @@ if __name__ == "__main__": modulo = params["modulo"] backend = "OT" if params["backend"] is None else params["backend"] ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") + # Copy generated code to the ezpc directory os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) os.chdir(ezpc_dir) ezpc_args = "" - ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac".format( - bl=lib_bitlength, target=target + ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( + bl=bitlength, target=target ) output_name = ezpc_file_name[:-5] + "0.cpp" if modulo is not None: @@ -174,9 +173,59 @@ if __name__ == "__main__": output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) if target in ["PORTHOS"]: ezpc_args += "--sf {} ".format(scale) + os.system( "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args ) os.system( "cp {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) ) + output_file = os.path.join(model_abs_dir, output_name) + + if target == "PORTHOS2PC": + program_name = model_base_name + "_" + target + "_" + backend + ".out" + else: + program_name = model_base_name + "_" + target + ".out" + program_path = os.path.join(model_abs_dir, program_name) + os.chdir(model_abs_dir) + if target in [ "CPP", "CPPRING"]: + os.system( + "g++ -O3 -w {file} -o {output}".format(file=output_file, output=program_path) + ) + elif target == "PORTHOS": + porthos_src = os.path.join(athos_dir, "..", "Porthos", "src") + porthos_lib = os.path.join(porthos_src, "build", "lib") + if os.path.exists(porthos_lib): + os.system( + """g++ -O3 -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ + -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ + -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ + -o {output}""".format(porthos_lib=porthos_lib, porthos_headers=porthos_src, + file=output_file, output=program_path) + ) + else: + print("Not compiling generated code. Please follow the readme and build Porthos.") + elif target == "PORTHOS2PC": + sci = os.path.join(athos_dir, "..", "SCI") + sci_src = os.path.join(sci, "src") + sci_lib = os.path.join(sci, "build", "lib") + eigen_path = os.path.join(sci, "extern", "eigen") + seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") + if os.path.exists(sci_lib): + os.system( + """g++ -O3 -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ + -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ + -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ + -o {output}""".format(eigen=eigen_path, sci_src=sci_src, + file=output_file,sci_lib=sci_lib,seal=seal_lib_path, output=program_path) + ) + else: + print("Not compiling generated code. Please follow the readme and build SCI.") + + os.chdir(cwd) + return (program_path, weights_path) + +if __name__ == "__main__": + args = parse_args() + params = parse_config.get_params(args.config) + generate_code(params) \ No newline at end of file