Skip to content
Snippets Groups Projects
Commit 818b8c2c authored by Bhatu's avatar Bhatu
Browse files

Handle TransformGraph for cases where constant outputs where converted to vars.

If the graph has a constant output, it will get converted to a variable. While
dumping graph_defs TransformGraph needs to be able to find that output. So we
teach it to find the newly created variable.
parent 2aa79b12
No related branches found
No related tags found
No related merge requests found
......@@ -31,7 +31,21 @@ def strip_variable_init_constants(graph_def, input_tensor_names, output_tensor_n
'remove_nodes(op=Identity)',
'strip_unused_nodes',
]
optimized_graph_def = TransformGraph(graph_def, input_tensor_names, output_tensor_names, transforms)
# Sanity check if output/input nodes were constant and replaced with variables.
all_node_names = set([i.name for i in graph_def.node])
def get_true_names(tensor_names, all_nodes):
real_names = []
for i in tensor_names:
if i not in all_nodes:
var_name = i + "_mpc_const_var"
if var_name in all_nodes:
real_names.append(var_name)
else:
real_names.append(i)
return real_names
real_input_names = get_true_names(input_tensor_names, all_node_names)
real_output_names = get_true_names(output_tensor_names, all_node_names)
optimized_graph_def = TransformGraph(graph_def, real_input_names, real_output_names, transforms)
return optimized_graph_def
def save_graphdef(graph_def):
......@@ -179,6 +193,8 @@ def save_weights(optimized_graph_def, sess, feed_dict, filename, scaling_factor)
values = sess.run(graph_vars, feed_dict)
with open(filename, "w") as ff:
for val in values:
if val.shape == (0,): #Empty array, nothing to dump.
continue
for xx in numpy.nditer(val, order="C"):
ff.write(str(int(xx * (1 << scaling_factor))) + " ")
ff.write("\n")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment