# get_pred_tf_graph.py (3.21 KiB)
import tensorflow as tf
import numpy as np
import argparse
from tf_graph_io import *
from tf_graph_trans import *
import os.path
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'TFCompiler'))
import DumpTFMtData
from os import path
def check_operation_exists(graph, tensor_name):
    """Tell whether *graph* holds an operation whose name equals *tensor_name*.

    The comparison is against operation names as reported by
    ``graph.get_operations()``, so callers pass an op name (no ``:0`` suffix).

    Returns:
        ``True`` if a matching operation is found, ``False`` otherwise.
    """
    return any(op.name == tensor_name for op in graph.get_operations())
def numpy_float_array_to_float_val_str(input_array):
    """Serialize every element of *input_array* as text, one value per line.

    Elements are visited with ``np.nditer`` using its default ``order='K'``
    (memory-layout order). Each element is rendered with ``str()`` and followed
    by a newline, including the last one.

    Args:
        input_array: an ndarray or any array-like accepted by ``np.asarray``.

    Returns:
        The newline-terminated string, or ``''`` when the array has no elements.
    """
    values = np.asarray(input_array)
    # np.nditer raises ValueError on zero-sized operands unless told otherwise,
    # so an empty prediction is handled explicitly.
    if values.size == 0:
        return ''
    # A single join is linear; repeated `+=` on a str can degrade to O(n^2).
    return ''.join(str(val) + '\n' for val in np.nditer(values))
def compile(model_fname, input_t_name, output_t_name, input_np_arr, output_fname):
    """Run a preprocessed TensorFlow graph on one input and save the prediction.

    Steps:
    - Validate the arguments.
    - Load the graph with ``load_pb``.
    - Feed the array read from *input_np_arr* into the first output of
      operation *input_t_name*.
    - Evaluate the first output of operation *output_t_name*.
    - Write every element of the result to *output_fname*, one value per line.

    Args:
        model_fname: path to a ``.pb`` model produced by
            preprocess_frozen_tf_graph.py. The file name itself (not its
            directory) must contain ``mpc_processed_``.
        input_t_name: op name of the graph input (no ``:0`` suffix).
        output_t_name: op name of the graph output (no ``:0`` suffix).
        input_np_arr: path of the file passed to ``np.load``. It is resolved
            against the caller's working directory.
        output_fname: path of the text file to write. NOTE: a relative path is
            resolved against the model's directory, because the working
            directory is switched there before the prediction runs.

    Raises:
        SystemExit: via ``sys.exit`` when any validation check fails.
    """
    if not model_fname.endswith('.pb'):
        sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)")
    # Inspect only the file name: a directory called mpc_processed_* must not
    # let an unprocessed model through.
    if "mpc_processed_" not in os.path.basename(model_fname):
        sys.exit("""Please process model using preprocess_frozen_tf_graph.py.
This will optimise it and generate a new .pb with mpc_processed prefix.
Use that with this script.""")
    print("Loading processed tf graph ", model_fname)
    graph = load_pb(model_fname)
    if not check_operation_exists(graph, input_t_name):
        sys.exit(input_t_name + " input does not exist in the graph")
    if not check_operation_exists(graph, output_t_name):
        sys.exit(output_t_name + " output does not exist in the graph")
    if not os.path.isfile(input_np_arr):
        sys.exit(input_np_arr + " file does not exist.")
    input_t = graph.get_operation_by_name(input_t_name).outputs[0]
    output_t = graph.get_operation_by_name(output_t_name).outputs[0]
    # NOTE(security): allow_pickle=True lets a crafted .npy file execute
    # arbitrary code when loaded; only pass trusted input files.
    np_input_t = np.load(input_np_arr, allow_pickle=True)
    feed_dict = {input_t: np_input_t}
    with graph.as_default():
        with tf.Session() as sess:
            # Run initializers generated by preprocessing
            if check_operation_exists(graph, 'init_constvars'):
                sess.run(graph.get_operation_by_name('init_constvars'))
            else:
                sess.run(tf.global_variables_initializer())
            # As before, the prediction runs (and output_fname is opened) from
            # the model's directory. The caller's cwd is restored afterwards so
            # the switch does not leak out of this function.
            model_dir = os.path.realpath(os.path.dirname(model_fname))
            caller_cwd = os.getcwd()
            os.chdir(model_dir)
            try:
                output = sess.run(output_t, feed_dict)
                with open(output_fname, 'w') as f:
                    f.write(numpy_float_array_to_float_val_str(output))
            finally:
                os.chdir(caller_cwd)
def boolean_string(s):
    """Convert the exact strings ``'True'`` / ``'False'`` to a bool.

    Intended as an argparse ``type=`` converter: anything other than those two
    spellings (case-sensitive) is rejected.

    Raises:
        ValueError: when *s* is not exactly ``'True'`` or ``'False'``.
    """
    lookup = {'True': True, 'False': False}
    try:
        return lookup[s]
    except KeyError:
        raise ValueError('Not a valid boolean string') from None
def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``modelName``, ``inputTensorName``,
        ``outputTensorName``, ``inputTensorNumpyArr`` and ``outputFileName``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)")
    parser.add_argument("--inputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, don't add ':0' suffix)")
    parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the output tensor for the model. (Op name, don't add ':0' suffix)")
    parser.add_argument("--inputTensorNumpyArr", required=True, type=str, help="Name of the input tensor numpy array file for the model.")
    parser.add_argument("--outputFileName", required=True, type=str, help="Name of the output file to store the prediction.")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: read the flags, then run the model and save its output.
    cli = parse_args()
    compile(
        model_fname=cli.modelName,
        input_t_name=cli.inputTensorName,
        output_t_name=cli.outputTensorName,
        input_np_arr=cli.inputTensorNumpyArr,
        output_fname=cli.outputFileName,
    )