Skip to content
Snippets Groups Projects
get_pred_tf_graph.py 3.21 KiB
import tensorflow as tf
import numpy as np
import argparse

from tf_graph_io import *
from tf_graph_trans import *

import os.path
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'TFCompiler'))
import DumpTFMtData

from os import path

def check_operation_exists(graph, tensor_name):
  """Tell whether `graph` contains an operation named exactly `tensor_name`.

  Args:
    graph: object exposing get_operations(); each returned op has `.name`.
    tensor_name: op name to look for (no ':0' output suffix).

  Returns:
    True if some operation's name equals `tensor_name`, else False.
  """
  # Stop scanning at the first matching op instead of collecting every name.
  return any(op.name == tensor_name for op in graph.get_operations())

def numpy_float_array_to_float_val_str(input_array):
  """Serialize every element of `input_array` as text, one value per line.

  Elements are visited with np.nditer in its default ('K', memory-order)
  traversal and each one is formatted with str().

  Args:
    input_array: array-like accepted by np.nditer.

  Returns:
    A string with each element followed by '\\n'; '' for an empty array.
  """
  # 'zerosize_ok' lets an empty array produce '' instead of nditer raising
  # ValueError; ''.join avoids the quadratic cost of repeated `+=`.
  return ''.join(
    str(val) + '\n'
    for val in np.nditer(input_array, flags=['zerosize_ok'])
  )

def compile(model_fname, input_t_name, output_t_name, input_np_arr, output_fname):
  """Run a preprocessed frozen TF graph on a numpy input and dump the prediction.

  Args:
    model_fname: path to the processed model; must end in '.pb' and contain
      'mpc_processed_' (output of preprocess_frozen_tf_graph.py).
    input_t_name: op name of the input node (no ':0' suffix); its first
      output is fed.
    output_t_name: op name of the output node (no ':0' suffix); its first
      output is evaluated.
    input_np_arr: path to a file loadable by np.load, fed as the input value.
    output_fname: text file to write, one output element per line.
      NOTE: the working directory is switched to the model's directory before
      this file is opened, so a relative path resolves against the model
      directory, not the caller's working directory.

  Raises:
    SystemExit: if the model path, either op name, or the input file is invalid.
  """
  if not model_fname.endswith('.pb'):
    sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)")
  if "mpc_processed_" not in model_fname:
    sys.exit("""Please process model using preprocess_frozen_tf_graph.py.
This will optimise it and generate a new .pb with mpc_processed prefix.
Use that with this script.""")

  print("Loading processed tf graph ", model_fname)
  graph = load_pb(model_fname)

  if not check_operation_exists(graph, input_t_name):
    sys.exit(input_t_name + " input does not exist in the graph")
  if not check_operation_exists(graph, output_t_name):
    sys.exit(output_t_name + " output does not exist in the graph")
  if not os.path.isfile(input_np_arr):
    sys.exit(input_np_arr + " file does not exist.")

  input_t = graph.get_operation_by_name(input_t_name).outputs[0]
  output_t = graph.get_operation_by_name(output_t_name).outputs[0]

  # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
  # objects, which can execute code. Only pass trusted input files.
  np_input_t = np.load(input_np_arr, allow_pickle=True)

  feed_dict = {input_t: np_input_t}
  prev_cwd = os.getcwd()
  try:
    with graph.as_default():
      with tf.Session() as sess:
        # Run initializers generated by preprocessing
        if check_operation_exists(graph, 'init_constvars'):
          sess.run(graph.get_operation_by_name('init_constvars'))
        else:
          sess.run(tf.global_variables_initializer())
        model_dir = os.path.realpath(os.path.dirname(model_fname))
        os.chdir(model_dir)
        output = sess.run(output_t, feed_dict)
        with open(output_fname, 'w') as f:
          f.write(numpy_float_array_to_float_val_str(output))
  finally:
    # Undo the chdir above so callers are not left in the model directory.
    os.chdir(prev_cwd)

def boolean_string(s):
  """Convert the literal strings 'True' / 'False' to a bool.

  Args:
    s: value to convert; only the exact strings 'True' and 'False' are accepted.

  Returns:
    True for 'True', False for 'False'.

  Raises:
    ValueError: for any other value.
  """
  accepted = {'True', 'False'}
  if s in accepted:
    return s == 'True'
  raise ValueError('Not a valid boolean string')

def parse_args(argv=None):
  """Parse this script's command-line options.

  Args:
    argv: optional list of argument strings to parse instead of
      sys.argv[1:]. The default None reads sys.argv, as before.

  Returns:
    argparse.Namespace with attributes modelName, inputTensorName,
    outputTensorName, inputTensorNumpyArr and outputFileName.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)")
  parser.add_argument("--inputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, don't add ':0' suffix)")
  parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the output tensor for the model. (Op name, don't add ':0' suffix)")
  parser.add_argument("--inputTensorNumpyArr", required=True, type=str, help="Name of the input tensor numpy array file for the model.")
  parser.add_argument("--outputFileName", required=True, type=str, help="Name of the output file to store the prediction.")
  return parser.parse_args(argv)

if __name__ == '__main__':
  # Script entry point: read CLI options and run one prediction.
  cli = parse_args()
  compile(
    cli.modelName,
    cli.inputTensorName,
    cli.outputTensorName,
    cli.inputTensorNumpyArr,
    cli.outputFileName,
  )