diff --git a/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py b/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py index f60af753fbb3e2cf2ca145d1bf7465b920aafb5e..c0f69149b9ac317dce13127374f4e69434710f72 100644 --- a/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py +++ b/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py @@ -203,6 +203,7 @@ def build_parser(): ps.add_argument('--saveImgAndWtData', dest='saveImgAndWtData', type=bool, help='bool to indicate if to save img and model weights', required=False) ps.add_argument('--savePreTrainedWeightsFloat', dest='savePreTrainedWeightsFloat', type=bool, help='bool to indicate if to save model weights float', required=False) ps.add_argument('--savePreTrainedWeightsInt', dest='savePreTrainedWeightsInt', type=bool, help='bool to indicate if to save model weights int', required=False) + ps.add_argument('--saveImgAndWeightsSeparately', dest='saveImgAndWeightsSeparately', type=bool, help='bool to indicate if to save image and model weights int separately', required=False) ps.add_argument('--scalingFac', dest='scalingFac', type=int, help='scalingFac', default=15, required=False) return ps @@ -266,9 +267,12 @@ def main(): if options.saveImgAndWtData: DumpTFMtData.dumpImgAndWeightsData(sess, imageData, all_weights, 'SqNetImgNet_img_input.inp', options.scalingFac, alreadyEvaluated=True) if options.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeights(sess, all_weights, 'SqNet_img_input_weights_int.inp', options.scalingFac, 'w', alreadyEvaluated=True) + DumpTFMtData.dumpTrainedWeightsInt(sess, all_weights, 'SqNet_trained_weights_int.inp', options.scalingFac, 'w', alreadyEvaluated=True) if options.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, all_weights, 'SqNet_img_input_weights_float.inp', 'w', alreadyEvaluated=True) - + DumpTFMtData.dumpTrainedWeightsFloat(sess, all_weights, 'SqNet_trained_weights_float.inp', 'w', alreadyEvaluated=True) + if options.saveImgAndWeightsSeparately: + DumpTFMtData.dumpTrainedWeightsInt(sess, all_weights, 'SqNet_trained_weights_int.inp', options.scalingFac, 'w', alreadyEvaluated=True) + DumpTFMtData.dumpImageDataInt(imageData, 'SqNet_image_data.inp', options.scalingFac, 'w') + if __name__ == '__main__': main() \ No newline at end of file