"""Run a preprocessed (mpc_processed_*) frozen TF graph on a numpy input and
write the resulting output tensor to a text file, one value per line."""
import tensorflow as tf
import numpy as np
import argparse
from tf_graph_io import *
from tf_graph_trans import *
import os.path
import sys

# Must precede `import DumpTFMtData`: that module lives in ../TFCompiler.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'TFCompiler'))
import DumpTFMtData
from os import path


def check_operation_exists(graph, tensor_name):
    """Return True if ``graph`` contains an operation named ``tensor_name``.

    The name is compared against ``Operation.name``, so it must be an op name,
    not a tensor name with an output-index suffix.
    """
    return any(op.name == tensor_name for op in graph.get_operations())


def numpy_float_array_to_float_val_str(input_array):
    """Serialise every element of ``input_array`` as text, one value per line.

    Elements are visited in ``np.nditer`` order and each is followed by a
    newline (including the last); an empty array yields an empty string.
    """
    # str() on each 0-d element keeps numpy's own formatting for the element's
    # dtype; join avoids the quadratic cost of repeated string concatenation.
    return "".join(str(val) + "\n" for val in np.nditer(input_array))


def compile(model_fname, input_t_name, output_t_name, input_np_arr, output_fname):
    """Run a processed TF graph on a saved numpy input and dump the output.

    Args:
      model_fname: path to a ``.pb`` model whose path contains
        ``mpc_processed_`` (i.e. produced by preprocess_frozen_tf_graph.py).
      input_t_name: op name of the input; its first output is fed.
      output_t_name: op name of the output; its first output is fetched.
      input_np_arr: path to a file loaded with ``np.load``.
      output_fname: file the output values are written to, one per line.
        NOTE: the working directory is changed to the model's directory
        before this file is opened, so a relative path is resolved against
        the model directory, not the caller's original working directory.

    Terminates the process via ``sys.exit`` with a message when the model
    path, op names, or input file are invalid.
    """
    if not model_fname.endswith('.pb'):
        sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)")
    if "mpc_processed_" not in model_fname:
        sys.exit("""Please process model using preprocess_frozen_tf_graph.py. This will optimise it and generate a new .pb with mpc_processed prefix. Use that with this script.""")

    print("Loading processed tf graph ", model_fname)
    graph = load_pb(model_fname)

    if not check_operation_exists(graph, input_t_name):
        sys.exit(input_t_name + " input does not exist in the graph")
    if not check_operation_exists(graph, output_t_name):
        sys.exit(output_t_name + " output does not exist in the graph")
    if not os.path.isfile(input_np_arr):
        sys.exit(input_np_arr + " file does not exist.")

    input_t = graph.get_operation_by_name(input_t_name).outputs[0]
    output_t = graph.get_operation_by_name(output_t_name).outputs[0]

    # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
    # Python objects, which can execute code. Only use trusted input files.
    np_input_t = np.load(input_np_arr, allow_pickle=True)
    feed_dict = {input_t: np_input_t}

    with graph.as_default():
        with tf.Session() as sess:
            # Run initializers generated by preprocessing, falling back to the
            # generic variable initializer when none were emitted.
            if check_operation_exists(graph, 'init_constvars'):
                sess.run(graph.get_operation_by_name('init_constvars'))
            else:
                sess.run(tf.global_variables_initializer())
            # Process-wide side effect (not restored afterwards); it also
            # changes how a relative output_fname is resolved below.
            model_dir = os.path.realpath(os.path.dirname(model_fname))
            os.chdir(model_dir)
            output = sess.run(output_t, feed_dict)
            with open(output_fname, 'w') as f:
                f.write(numpy_float_array_to_float_val_str(output))


def boolean_string(s):
    """Parse exactly 'True' or 'False' into a bool; raise ValueError otherwise."""
    if s not in {'False', 'True'}:
        raise ValueError('Not a valid boolean string')
    return s == 'True'


def parse_args():
    """Parse the command-line arguments for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelName", required=True, type=str,
                        help="Name of processed tensorflow model (mpc_processed*.pb)")
    parser.add_argument("--inputTensorName", required=True, type=str,
                        help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)")
    parser.add_argument("--outputTensorName", required=True, type=str,
                        help="Name of the output tensor for the model. (Op name, dont add '/:0' suffix)")
    parser.add_argument("--inputTensorNumpyArr", required=True, type=str,
                        help="Name of the input tensor numpy array file for the model.")
    parser.add_argument("--outputFileName", required=True, type=str,
                        help="Name of the output file to store the prediction.")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    compile(args.modelName, args.inputTensorName, args.outputTensorName,
            args.inputTensorNumpyArr, args.outputFileName)