"""Freeze a ResNet v1.5 inference graph into a single GraphDef file.

The script builds the model in inference mode on top of an ``input``
placeholder, optionally rewrites the graph with the quantization eval-graph
transform, restores the weights from ``--checkpoint`` (or initializes them
when no checkpoint is given), converts the variables to constants and writes
the resulting GraphDef to ``--output_file``.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
import horovod.tensorflow as hvd

from model import resnet_v1_5

tf.app.flags.DEFINE_string(
    'model_name', 'resnet50_v1.5',
    'The name of the architecture to save. The default name was being '
    'used to train the model')

tf.app.flags.DEFINE_integer(
    'image_size', 224,
    'The image size to use, otherwise use the model default_image_size.')

tf.app.flags.DEFINE_integer(
    'num_classes', 1001,
    'The number of classes to predict.')

tf.app.flags.DEFINE_integer(
    'batch_size', None,
    'Batch size for the exported model. Defaulted to "None" so batch size can '
    'be specified at model runtime.')

# The two format flags are distinct: `input_format` decides the shape of the
# input placeholder, `compute_format` is handed to the model constructor.
tf.app.flags.DEFINE_string(
    'input_format', 'NHWC',
    'The data format of the input placeholder (NCHW or NHWC).')

tf.app.flags.DEFINE_string(
    'compute_format', 'NHWC',
    'The dataformat used by the layers in the model')

tf.app.flags.DEFINE_string(
    'checkpoint', '',
    'The trained model checkpoint.')

tf.app.flags.DEFINE_string(
    'output_file', '',
    'Where to save the resulting file to.')

tf.app.flags.DEFINE_bool(
    'quantize', False,
    'whether to use quantized graph or not.')

tf.app.flags.DEFINE_bool(
    'symmetric', False,
    'Using symmetric quantization or not.')

tf.app.flags.DEFINE_bool(
    'use_qdq', False,
    'Use quantize and dequantize op instead of fake quant op')

tf.app.flags.DEFINE_bool(
    'use_final_conv', False,
    'Whether to build the model with use_final_conv enabled (forwarded to '
    'ResnetModel.build_model).')

tf.app.flags.DEFINE_bool(
    'write_text_graphdef', False,
    'Whether to write a text version of graphdef.')

FLAGS = tf.app.flags.FLAGS


def _placeholder_shape(input_format, batch_size, image_size):
    """Return the shape list of the input placeholder.

    Args:
        input_format: 'NCHW' selects a channels-first layout; any other value
            yields the channels-last (NHWC) layout.
        batch_size: Leading dimension; may be None to leave it unspecified.
        image_size: Height and width of the (square) input image.

    Returns:
        A four-element list with three channels in the position dictated by
        `input_format`.
    """
    if input_format == 'NCHW':
        return [batch_size, 3, image_size, image_size]
    return [batch_size, image_size, image_size, 3]


def main(_):
    """Build the inference graph, freeze it and write it to --output_file.

    Raises:
        ValueError: If --output_file was not supplied.
    """
    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
    hvd.init()

    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        input_shape = _placeholder_shape(
            FLAGS.input_format, FLAGS.batch_size, FLAGS.image_size)
        input_images = tf.placeholder(
            name='input', dtype=tf.float32, shape=input_shape)

        network = resnet_v1_5.ResnetModel(
            FLAGS.model_name, FLAGS.num_classes,
            FLAGS.compute_format, FLAGS.input_format)
        probs, logits = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                # No checkpoint given: freeze freshly initialized variables.
                sess.run(tf.global_variables_initializer())

            graph_def = graph.as_graph_def()
            # Only the op producing `probs` is kept as an output node.
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, [probs.op.name])

    # os.path.dirname() returns '' for a bare filename; fall back to the
    # current directory so write_graph is always given a real directory.
    output_dir = os.path.dirname(FLAGS.output_file) or '.'

    # Write out the frozen graph
    tf.io.write_graph(
        frozen_graph_def,
        output_dir,
        os.path.basename(FLAGS.output_file),
        as_text=FLAGS.write_text_graphdef)


if __name__ == '__main__':
    tf.app.run()