Skip to content
Snippets Groups Projects
Commit 2aa79b12 authored by Bhatu's avatar Bhatu
Browse files

Minor improvements to compiler scripts, handle const to var better for some ops

parent 4346e966
No related branches found
No related tags found
No related merge requests found
......@@ -120,11 +120,12 @@ def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weig
)
tf_graph_io.dump_graph_def_pb(
optimized_graph_def, "optimised_" + model_fname
optimized_graph_def, "optimised_" + model_name + ".pb"
)
DumpTFMtData.save_graphdef(optimized_graph_def)
DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict)
print("Model compilation done.")
weights_path = ""
if save_weights:
weights_fname = (
model_name
......@@ -140,8 +141,9 @@ def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weig
DumpTFMtData.save_weights(
optimized_graph_def, sess, feed_dict, weights_fname, scaling_factor
)
weights_path = os.path.join(model_dir, weights_fname)
os.chdir(cwd)
return
return weights_path
def parse_args():
......
......@@ -61,6 +61,9 @@ def get_white_list(graph):
mean_axes_ops = set(
i.inputs[1].op.name for i in graph.get_operations() if i.type == "Mean"
)
sum_axes_ops = set(
i.inputs[1].op.name for i in graph.get_operations() if i.type == "Sum"
)
split_dim_ops = set(
i.inputs[0].op.name for i in graph.get_operations() if i.type == "Split"
)
......@@ -69,14 +72,24 @@ def get_white_list(graph):
for i in graph.get_operations()
if i.type == "ConcatV2" or i.type == "Concat"
)
argmax_axes_ops = set(
i.inputs[1].op.name for i in graph.get_operations() if i.type == "ArgMax"
)
divisor_ops = set(
i.inputs[1].op.name for i in graph.get_operations() if i.type in ["FloorDiv", "RealDiv"]
)
white_list = (
transp_perm_ops
| padding_ops
| slice_begin_ops
| slice_size_ops
| mean_axes_ops
| sum_axes_ops
| split_dim_ops
| concat_axes_ops
| argmax_axes_ops
| divisor_ops
)
return list(white_list)
......
......@@ -115,7 +115,7 @@ def get_shape_list(shape_string):
if shape_string == "":
return shape
for i in shape_string.split(","):
assert i.isnumeric(), "Given input shape has non-integer values"
assert i.isnumeric(), "Given input shape has non-integer value : {}".format(i)
shape.append(int(i))
return shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment