Commit 75469ab4 authored by Bhatu

Add scripts to compile tensorflow protobuf dumped models

parent 1cb4cf94
@@ -47,7 +47,7 @@ usage() {
echo -e "<--disable-liveness-opti> :: Disable Liveness Optimization."
echo -e "<--disable-trunc-opti> :: Disable truncation placement optimization."
echo -e "<--exec-python> <num of args for python script> <args for python script>... :: Execute the python script which is passed for compilation.";
echo -e "<--help> :: help options.";
echo -e "<-h|--help> :: help options.";
exit 1;
}
@@ -99,7 +99,7 @@ do
done
EXECPYTHON=Y
;;
---help)
+-h|--help)
HELP=Y
shift # past one arg
;;
......
#!/bin/bash
# Authors: Nishant Kumar, Pratik Bhatu.
# Copyright:
# Copyright (c) 2018 Microsoft Research
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##########################################################################################
# This is the CrypTFlow compilation script.
# Use this on a network to compile to MPC protocol.
# By default, this assumes there is a ezpc repo one level up - if you want to change it,
# please use Paths.config to override the default paths.
# Same goes for Porthos repository.
# NOTE : When overriding paths in Paths.config, assumption is there is no '/' at the end.
##########################################################################################
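# A minimal Paths.config sketch (paths are hypothetical, shown only to illustrate the
# expected format: plain shell variable assignments, with no trailing '/'):
#   EzPCDir="/path/to/EzPC"
#   PorthosDir="/path/to/Porthos"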
# Load overridden paths from config file
. Paths.config
echo -e "Loaded paths: EzPCDir - $EzPCDir, PorthosDir - $PorthosDir"
usage() {
echo -e "CrypTFlow compilation script. Options:";
echo -e "<-b|--bitlen> <bitlen> :: Bit length to compile for. Defaults to 64";
echo -e "<-s|--scaling-fac> <sf> :: Scaling factor to compile for. Defaults to 12.";
echo -e "<-t|--target> <target> :: Compilation target. Possible options: ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC. Defaults to CPP.";
echo -e "<-f|--filename> :: Tensorflow protobuf file to compile."
echo -e "<--modulo> :: Modulo to be used for shares. Applicable for CPPRING/PORTHOS2PC backend. For PORTHOS2PC, for backend type OT, this should be power of 2 and for backend type HE, this should be a prime."
echo -e "<--backend> :: Backend to be used - OT/HE (default OT). Applicable for PORTHOS2PC backend."
echo -e "<--disable-hlil-all-opti> :: Disable all optimizations in HLIL."
echo -e "<--disable-rmo> :: Disable Relu-Maxpool optimization."
echo -e "<--disable-liveness-opti> :: Disable Liveness Optimization."
echo -e "<--disable-trunc-opti> :: Disable truncation placement optimization."
echo -e "<-h|--help> :: help options.";
exit 1;
}
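# Example invocation (model path is hypothetical; the script name assumes this file is
# saved as CompileTFGraph.sh, as referenced in the README change below):
#   ./CompileTFGraph.sh -t PORTHOS -b 64 -s 12 -f ./networks/my_model.pb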
BITLEN="64"
SCALINGFACTOR="12"
COMPILATIONTARGET="CPP"
EXECPYTHONARGS=""
while [[ $# -gt 0 ]]
do
key="$1"
case $key in
-b|--bitlen)
BITLEN="$2"
shift # past argument
shift # past value
;;
-s|--scaling-fac)
SCALINGFACTOR="$2"
shift # past argument
shift # past value
;;
-t|--target)
COMPILATIONTARGET="$2"
shift # past argument
shift # past value
;;
-f|--filename)
FILENAME="$2"
shift
shift
;;
--modulo)
MODULO="$2"
shift
shift
;;
--backend)
BACKEND="$2"
shift
shift
;;
-h|--help)
HELP=Y
shift # past one arg
;;
--disable-hlil-all-opti)
DisableHLILAllOpti=Y
shift # past one arg
;;
--disable-rmo)
DisableRMO=Y
shift # past one arg
;;
--disable-liveness-opti)
DisableLivenessOpti=Y
shift # past one arg
;;
--disable-trunc-opti)
DisableTruncOpti=Y
shift # past one arg
;;
*) # unknown option
usage
;;
esac
done
if [ ! -z "$HELP" ] || [ -z "$FILENAME" ] ; then
usage
fi
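# The bundled EzPC TF libraries come only in 32-bit and 64-bit variants (Library32_* /
# Library64_*), so keep the user-requested width for codegen (ACTUALBITLEN) and snap the
# library selection (BITLEN) to 32 or 64.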
ACTUALBITLEN="${BITLEN}"
if [ "$ACTUALBITLEN" -gt 32 ]; then
BITLEN="64"
else
BITLEN="32"
fi
compilationTargetLower=$(echo "$COMPILATIONTARGET" | awk '{print tolower($0)}')
compilationTargetHigher=$(echo "$COMPILATIONTARGET" | awk '{print toupper($0)}')
givenDirPath=$(dirname "$FILENAME")
fullDirPath=$(realpath "$givenDirPath")
porthosFullDirPath=$( realpath "$PorthosDir")
baseFileName=$(basename -- "$FILENAME")
extension="${baseFileName##*.}"
actualFileName="${baseFileName%.*}" #without extension
fullFilePath=$(realpath "$FILENAME")
ezpcOutputFileName=${actualFileName}'_'${BITLEN}'_'${compilationTargetLower}
ezpcOutputFullFileName=${fullDirPath}'/'${ezpcOutputFileName}'.ezpc'
finalCodeOutputFileName=${ezpcOutputFileName}'0.cpp'
if [ "$extension" != "pb" ]; then
echo -e "Error: Provide a tensorflow pb file to compile."
usage
fi
cd "$fullDirPath"
cd - > /dev/null
cd ./TFCompiler
python3 ProcessTFGraph.py "$fullFilePath"
cd ../SeeDot
seedotArgs="--astFile ${fullDirPath}/astOutput.pkl --consSF ${SCALINGFACTOR} --bitlen ${ACTUALBITLEN} --outputFileName ${ezpcOutputFullFileName}"
#Temporarily always disable trunc optimization. TODO: Remove when fixed.
DisableTruncOpti=Y
if [ ! -z "$DisableHLILAllOpti" ]; then
seedotArgs="${seedotArgs} --disableAllOpti True"
fi
if [ ! -z "$DisableRMO" ]; then
seedotArgs="${seedotArgs} --disableRMO True"
fi
if [ ! -z "$DisableLivenessOpti" ]; then
seedotArgs="${seedotArgs} --disableLivenessOpti True"
fi
if [ ! -z "$DisableTruncOpti" ]; then
seedotArgs="${seedotArgs} --disableTruncOpti True"
fi
python3 SeeDot.py $seedotArgs
cd ..
libraryFile="$compilationTargetLower"
if [ "$compilationTargetLower" == "aby" ] || [ "$compilationTargetLower" == "cppring" ] ; then
libraryFile="cpp"
fi
if [ "$libraryFile" == "cpp" ];then
# CPP/ABY backend
cat "./TFEzPCLibrary/Library${BITLEN}_${libraryFile}_pre.ezpc" "./TFEzPCLibrary/Library${BITLEN}_common.ezpc" "./TFEzPCLibrary/Library${BITLEN}_${libraryFile}_post.ezpc" "$ezpcOutputFullFileName" > temp
else
cat "./TFEzPCLibrary/Library${BITLEN}_${libraryFile}.ezpc" "./TFEzPCLibrary/Library${BITLEN}_common.ezpc" "$ezpcOutputFullFileName" > temp
fi
mv temp "$ezpcOutputFullFileName"
cp "$ezpcOutputFullFileName" "$EzPCDir/EzPC"
cd "$EzPCDir/EzPC"
eval `opam config env`
ezpcArgs="--bitlen ${ACTUALBITLEN} --codegen ${compilationTargetHigher} --disable-tac"
if [ ! -z "$MODULO" ]; then
ezpcArgs="${ezpcArgs} --modulo ${MODULO}"
fi
if [ ! -z "$BACKEND" ]; then
backendUpper=$(echo "$BACKEND" | awk '{print toupper($0)}')
ezpcArgs="${ezpcArgs} --backend ${backendUpper}"
finalCodeOutputFileName=${ezpcOutputFileName}_${backendUpper}'0.cpp'
fi
if [ "$compilationTargetLower" == "porthos" ] ; then
ezpcArgs="${ezpcArgs} --sf ${SCALINGFACTOR}"
fi
./ezpc.sh "$ezpcOutputFullFileName" ${ezpcArgs}
if [ "$compilationTargetLower" == "cpp" ] || [ "$compilationTargetLower" == "cppring" ] ; then
cd "$fullDirPath"
g++ -O3 "$finalCodeOutputFileName" -o "$actualFileName.out"
echo -e "All compilation done."
else
cd - > /dev/null
echo -e "All compilation done."
if hash clang-format 2> /dev/null; then
clang-format -style=LLVM $fullDirPath/$finalCodeOutputFileName > tmp_clang
mv tmp_clang $fullDirPath/$finalCodeOutputFileName
fi
fi
import tensorflow as tf
import numpy as np
import argparse
from tf_graph_io import *
from tf_graph_trans import *
import os.path
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'TFCompiler'))
import DumpTFMtData
from os import path
def check_operation_exists(graph, tensor_name):
op_list = [i.name for i in graph.get_operations()]
return tensor_name in op_list
def compile(model_fname, input_t_name, output_t_name, scaling_factor, save_weights, input_shape):
if not model_fname.endswith('.pb'):
sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)")
elif not "mpc_processed_" in model_fname:
sys.exit("""Please process model using preprocess_frozen_tf_graph.py.
This will optimise it and generate a new .pb with mpc_processed prefix.
Use that with this script.""")
else:
model_name = os.path.basename(model_fname)[:-3]
print("Loading processed tf graph ", model_fname)
graph = load_pb(model_fname)
if not check_operation_exists(graph, input_t_name):
sys.exit(input_t_name + " input does not exist in the graph")
if not check_operation_exists(graph, output_t_name):
sys.exit(output_t_name + " output does not exist in the graph")
input_t = graph.get_operation_by_name(input_t_name).outputs[0]
output_t = graph.get_operation_by_name(output_t_name).outputs[0]
# Create a dummy zero-filled tensor matching the model's input shape to use as input
inp_shape = input_t.shape.as_list()
if None in inp_shape:
if input_shape == []:
sys.exit("Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help.")
else:
inp_shape = input_shape
rand_inp_t = np.zeros(inp_shape)
feed_dict = {input_t: rand_inp_t}
with graph.as_default():
with tf.Session() as sess:
# Run initializers generated by preprocessing
if check_operation_exists(graph, 'init_constvars'):
sess.run(graph.get_operation_by_name('init_constvars'))
else:
sess.run(tf.global_variables_initializer())
# Dump sizeInfo, graphDef mtdata and weight dump in model folder.
model_dir = os.path.realpath(os.path.dirname(model_fname))
os.chdir(model_dir)
optimized_graph_def = DumpTFMtData.save_graph_metadata(output_t, sess, feed_dict)
print("Model compilation done.")
trainVarsName = [node.name for node in optimized_graph_def.node if node.op == "VariableV2" or node.op == "Variable"]
trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName))
if save_weights:
DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess)
weights_fname = model_name[len("mpc_processed_"):] + '_input_weights_fixedpt_scale_' + str(scaling_factor) + '.inp'
print("Dumping model weights in ", weights_fname, ". These are to be used as input for party which owns the model")
DumpTFMtData.dumpTrainedWeightsInt(sess, trainVars, weights_fname, scaling_factor, 'w')
def boolean_string(s):
if s not in {'False', 'True'}:
raise ValueError('Not a valid boolean string')
return s == 'True'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)")
parser.add_argument("--inputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)")
parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)")
parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)")
parser.add_argument("--saveWeights", type=boolean_string, default=False, help="Dump model weights in fixedpt {True/False}")
parser.add_argument("--inputTensorShape", type=str, default='', help="Comma separated list of shape for input tensor. eg: \"2,245,234,3\"")
args = parser.parse_args()
return args
def get_shape_list(shape_string):
if shape_string == '':
return []
return [int(i) for i in shape_string.split(",")]
if __name__ == '__main__':
args = parse_args()
shape_list = get_shape_list(args.inputTensorShape)
compile(args.modelName, args.inputTensorName, args.outputTensorName, args.sf, args.saveWeights, shape_list)
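# Example invocation (file and tensor names are hypothetical, and the script file name
# is assumed here to be CompileTFGraph.py):
#   python3 CompileTFGraph.py --modelName mpc_processed_model.pb \
#     --inputTensorName input --outputTensorName output \
#     --sf 12 --saveWeights True --inputTensorShape "1,224,224,3"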
'''
Authors: Pratik Bhatu.
Copyright:
Copyright (c) 2020 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import tensorflow as tf
import onnx
from onnx import shape_inference
......
'''
Authors: Pratik Bhatu.
Copyright:
Copyright (c) 2020 Microsoft Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import tensorflow as tf
model_filename = 'chest_xray_covid19_model.h5'
......
from tf_graph_io import *
from tf_graph_trans import *
import sys
import time
import os
# Transpose nodes require perm as compile time constants for parametric codegen
# So we don't eliminate the constants we need during compile time
def get_const_names(graph):
transp_perm_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Transpose')
padding_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Pad')
slice_begin_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Slice')
slice_size_ops = set(i.inputs[2].op.name for i in graph.get_operations() if i.type == 'Slice')
mean_axes_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Mean')
white_list = transp_perm_ops | padding_ops | slice_begin_ops | slice_size_ops | mean_axes_ops
all_const_ops = set(i.name for i in graph.get_operations() if i.type == 'Const')
return list(all_const_ops - white_list)
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python preprocess_frozen_tf_graph.py tf_model_name.pb")
sys.exit()
else:
input_fname = sys.argv[1]
actual_fname = os.path.basename(input_fname)
dirname = os.path.dirname(input_fname)
output_fname = os.path.join(dirname, "mpc_processed_" + actual_fname)
print("Loading ", input_fname, "for processing.")
exec_graph = load_pb(input_fname)
print("\n\nThis process will take some time to run as we execute portions of the graph.\n\n")
time.sleep(5)
# Fold away all static computations
print("Running constant folding")
exec_graph = fold_splits(exec_graph)
exec_graph = fold_constants(exec_graph)
# Convert constants to variables so as to separate the data and the generated code
# Otherwise huge arrays will show up as constants in the generated code, thereby
# increasing binary size.
print("Convert frozen constants to variables")
exec_graph = convert_consts_to_var(exec_graph, get_const_names(exec_graph))
# At this stage the graph still has constants embedded in it
# in the assign nodes for variables. We cannot execute the graph without
# these constants. However after inferring the size, we can call remove_dead_nodes
# to optimize away the constants and assign nodes and make the graph amenable
# for codegen
dump_pb(exec_graph, output_fname)
print("The processed graph is dumped in ", output_fname)
import tensorflow as tf
from tensorflow.python.platform import gfile
def display_graph(graph, tensorboard_log_dir):
writer = tf.summary.FileWriter(tensorboard_log_dir, graph)
writer.close()
def load_pb(path_to_pb):
with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name="")
return graph
def dump_pb(graph, filename):
with tf.io.gfile.GFile(filename, 'wb') as f:
graph_def = graph.as_graph_def()
f.write(graph_def.SerializeToString())
def save_model(graph, model_name):
with graph.as_default():
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
save_path = tf.train.Saver().save(sess, model_name)
print("Model saved in path: %s" % save_path)
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge
from tensorflow.tools.graph_transforms import TransformGraph
def delete_nodes(graph, ops):
gd = graph.as_graph_def()
nodes_to_delete = set(op.name for op in ops)
new_gd = tf.compat.v1.GraphDef()
nodes_to_keep = []
for n in gd.node:
if not n.name in nodes_to_delete:
nodes_to_keep.append(n)
new_gd.node.extend(nodes_to_keep)
new_graph = tf.Graph()
with new_graph.as_default():
tf.import_graph_def(new_gd, name="")
return new_graph
def remove_dead_nodes(graph, input_tensors, output_tensors):
transforms = ['remove_nodes(op=Identity)', 'strip_unused_nodes']
in_list = [i.name for i in input_tensors]
out_list = [i.name for i in output_tensors]
optimized_graph_def = TransformGraph(graph.as_graph_def(), in_list, out_list, transforms)
with tf.Graph().as_default() as opt_graph:
tf.import_graph_def(optimized_graph_def, name="")
return opt_graph
def convert_consts_to_var(graph, const_names_list):
const_var_names_pairs = []
ops_to_delete = []
with graph.as_default():
var_list = []
for name in const_names_list:
#tensor = graph.get_tensor_by_name('{}:0'.format(name))
tensor = graph.get_operation_by_name(name).outputs[0]
with tf.Session() as sess:
t_value = sess.run(tensor)
t_name = '{}_const_var'.format(name)
var = tf.Variable(t_value, name=t_name)
const_var_names_pairs.append((name, t_name))
var_list.append(var)
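# Rewire the graph so every consumer of an original Const op now reads from the
# corresponding variable's '<name>/read' op, and queue the Const ops for deletion.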
for const_name, var_name in const_var_names_pairs:
const_op = graph.get_operation_by_name(const_name)
var_op = graph.get_operation_by_name('{}/read'.format(var_name))
ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op))
ops_to_delete.append(const_op)
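# Group the initializers of all newly created variables under a single op named
# 'init_constvars'; the compile script looks for this op and runs it instead of
# tf.global_variables_initializer().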
tf.compat.v1.variables_initializer(var_list, 'init_constvars')
return delete_nodes(graph, ops_to_delete)
def get_inputs(op):
return set(input.op for input in op.inputs)
def replace_node_with_const(node):
print("Trying to execute node {}".format(node.name))
graph = node.graph
with graph.as_default():
const_lists = []
with tf.Session() as sess:
for out_t in node.outputs:
const_val = sess.run(out_t)
const_op = tf.constant(const_val).op
const_lists.append(const_op)
ge.swap_outputs(ge.sgv(node), ge.sgv(const_lists))
def DFS(node, visited, const_map, deleted_nodes):
print("Visiting node {}".format(node.name))
visited.add(node)
if node.type == "Const":
const_map[node.name] = True
return True
if len(node.inputs) == 0:
const_map[node.name] = False
return False
for inp_node in get_inputs(node):
if not inp_node in visited:
isConst = DFS(inp_node, visited, const_map, deleted_nodes)
const_map[inp_node.name] = isConst
all_inputs_const = True
for inp_node in get_inputs(node):
all_inputs_const = all_inputs_const and const_map[inp_node.name]
if all_inputs_const:
const_map[node.name] = True
replace_node_with_const(node)
deleted_nodes.add(node)
return True
const_map[node.name] = False
return False
def get_dangling_consts_old(graph):
consts = [ i for i in graph.get_operations() if i.type == 'Const' ]
def has_users(op):
for i in op.outputs:
for j in i.consumers():
if j.type != 'Const':
return True
return False
return [ i for i in consts if not has_users(i)]
def get_dangling_consts(graph, deleted_nodes):
consts = [ i for i in graph.get_operations() if i.type == 'Const' ]
def has_users(op):
for i in op.outputs:
for j in i.consumers():
if j.type != 'Const' and (j not in deleted_nodes):
return True
return False
return [ i for i in consts if not has_users(i)]
def fold_constants(graph):
visited = set({})
const_map = {}
deleted_nodes = set({})
with graph.as_default():
for node in graph.get_operations():
if not node in visited:
isConst = DFS(node, visited, const_map, deleted_nodes)
if isConst:
replace_node_with_const(node)
deleted_nodes.add(node)
useless_consts = get_dangling_consts(graph, deleted_nodes)
print("No. of consts to be removed = {}".format(len(useless_consts)))
deleted_nodes.update(useless_consts)
graph = delete_nodes(graph, deleted_nodes)
consts = [ i for i in graph.get_operations() if i.type == 'Const' ]
print("No. of total consts still remaining = {}".format(len(consts)))
dang_consts = get_dangling_consts_old(graph)
print("No. of dang consts still remaining = {}".format(len(dang_consts)))
return graph
def replace_nodes_with_identity(graph, nop_splits):
with graph.as_default():
for split in nop_splits:
inp_var = split.inputs[1]
identity = tf.identity(inp_var).op
ge.swap_outputs(ge.sgv(split), ge.sgv(identity))
return graph
def fold_splits(graph):
with graph.as_default():
nop_splits = []
for node in graph.get_operations():
if node.type != "Split":
continue
if node.get_attr("num_split") == 1:
nop_splits.append(node)
replace_nodes_with_identity(graph, nop_splits)
graph = delete_nodes(graph, set(nop_splits))
return graph
@@ -18,7 +18,9 @@ The codebase is organized as follows:
- `TFCompiler`: This contains python modules which are called from the TensorFlow code for the dumping of TensorFlow metadata (required by Athos for compilation to MPC protocols).
- `TFEzPCLibrary`: This contains library code written in EzPC for the TensorFlow nodes required during compilation.
- `CompileTF.sh`: The Athos compilation script. Try `./CompileTF.sh --help` for options.
+- `CompileTFGraph.sh`: The Athos compilation script for protobuf models. Try `./CompileTFGraph.sh --help` for options.
- `Paths.config`: This can be used to override the default folders for EzPC and Porthos.
+- `CompilerScripts`: This folder contains scripts used for processing and compiling dumped models.
# Usage
Here we provide an example on how to use Athos to compile TensorFlow based ResNet-50 code to Porthos semi-honest 3PC protocol and subsequently run it. The relevant TensorFlow code for ResNet-50 can be found in `./Networks/ResNet/ResNet_main.py`.
......
@@ -73,7 +73,7 @@ def save_graph_metadata(output_tensor, sess, feed_dict):
return optimized_graph_def
-def updateWeightsForBN(optimized_graph_def, sess, feed_dict):
+def updateWeightsForBN(optimized_graph_def, sess, feed_dict={}):
def findNodeInGraphDefWithName(graphDef, curName):
for curNode in graphDef.node:
if curNode.name == curName:
@@ -82,18 +82,18 @@ def updateWeightsForBN(optimized_graph_def, sess, feed_dict):
print("Updating weights for BN...")
-graph = tf.get_default_graph()
+graph = sess.graph
graphDef = optimized_graph_def
for node in graphDef.node:
-if (node.op == 'FusedBatchNorm'):
-print("node.name = {0}".format(node.name))
+if (node.op == 'FusedBatchNorm' or node.op == 'FusedBatchNormV3'):
+print("Updating BN weight, node.name = {0}".format(node.name))
gamma = graph.get_operation_by_name(node.input[1]).outputs[0]
beta = graph.get_operation_by_name(node.input[2]).outputs[0]
mu = graph.get_operation_by_name(node.input[3]).outputs[0]
variance = graph.get_operation_by_name(node.input[4]).outputs[0]
-epsilon = 1e-5 # Taken from non-fused BN of TF
+epsilon = node.attr['epsilon'].f
rsigma = tf.rsqrt(variance + epsilon)
sess.run(tf.assign(gamma, gamma*rsigma))
......
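For reference, the standard batch-norm folding identity behind this update (standard math, not taken from the diff itself): with rsigma = 1/sqrt(variance + epsilon), the fused output
    gamma * (x - mu) * rsigma + beta
can be rewritten as gamma' * x + beta', where gamma' = gamma * rsigma and beta' = beta - gamma * mu * rsigma, which is why only the gamma and beta tensors need to be updated in place.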