Skip to content
Snippets Groups Projects
Commit 7a0f955a authored by Bhatu's avatar Bhatu
Browse files

Fix tf size inference for multiple output tensors.

There was an assumption that ops only have single tensor outputs.
However ops like split return multiple tensors. This fixes that.
parent 07901e61
No related branches found
No related tags found
No related merge requests found
......@@ -40,7 +40,7 @@ def get_optimized_graph_def(output_tensor):
def save_graph_metadata(output_tensor, sess, feed_dict):
#First save the graph def
graph_def = tf.get_default_graph().as_graph_def()
graph_def = sess.graph_def
transforms = [
'remove_nodes(op=Identity)',
'strip_unused_nodes',
......@@ -54,11 +54,17 @@ def save_graph_metadata(output_tensor, sess, feed_dict):
# Save size information for tensors on which output depends
tensors_to_evaluate = []
tensors_to_evaluate_names = []
graph = tf.get_default_graph()
graph = sess.graph
for node in optimized_graph_def.node:
cur_output = graph.get_operation_by_name(node.name).outputs[0]
tensors_to_evaluate.append(cur_output)
tensors_to_evaluate_names.append(node.name)
output_number = 0
for cur_output in graph.get_operation_by_name(node.name).outputs:
tensors_to_evaluate.append(cur_output)
if output_number == 0:
tensor_name = node.name
else:
tensor_name = cur_output.name
tensors_to_evaluate_names.append(tensor_name)
output_number += 1
tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict)
tensors_shape = list(map(lambda x : x.shape, tensors_evaluated))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment