DeepLearningExamples/TensorFlow/Classification/ConvNets/export_frozen_graph.py

104 lines
3.5 KiB
Python
Raw Normal View History

2020-06-22 01:41:23 +02:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import horovod.tensorflow as hvd
from model import resnet_v1_5
# Command-line flags. They are declared at import time through tf.app.flags
# and parsed by tf.app.run() before main() is invoked.
tf.app.flags.DEFINE_string(
    'model_name', 'resnet50_v1.5',
    'The name of the architecture to export. Must match the architecture '
    'that was used to train the checkpoint.')
tf.app.flags.DEFINE_integer(
    'image_size', 224,
    'The height and width (in pixels) of the square input images.')
tf.app.flags.DEFINE_integer(
    'num_classes', 1001,
    'The number of classes to predict.')
tf.app.flags.DEFINE_integer(
    'batch_size', None,
    'Batch size for the exported model. Defaulted to "None" so batch size can '
    'be specified at model runtime.')
tf.app.flags.DEFINE_string(
    'input_format', 'NHWC',
    'The data format of the input placeholder: NCHW or NHWC (any value other '
    'than NCHW is treated as NHWC).')
tf.app.flags.DEFINE_string(
    'compute_format', 'NHWC',
    'The dataformat used by the layers in the model')
tf.app.flags.DEFINE_string(
    'checkpoint', '',
    'The trained model checkpoint. If empty, the graph is exported with '
    'freshly initialized (untrained) weights.')
tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')
tf.app.flags.DEFINE_bool(
    'quantize', False, 'whether to use quantized graph or not.')
tf.app.flags.DEFINE_bool(
    'symmetric', False, 'Using symmetric quantization or not.')
tf.app.flags.DEFINE_bool(
    'use_qdq', False, 'Use quantize and dequantize op instead of fake quant op')
tf.app.flags.DEFINE_bool(
    'use_final_conv', False,
    'Whether to use a convolution instead of a fully-connected layer as the '
    'last layer of the network.')
tf.app.flags.DEFINE_bool(
    'write_text_graphdef', False,
    'Whether to write a text version of graphdef.')

FLAGS = tf.app.flags.FLAGS
def main(_):
    """Build the inference graph, load its weights and write it out frozen.

    Steps:
      1. Create a float32 placeholder named ``input``. Its layout follows
         ``--input_format`` and its batch dimension is ``--batch_size``
         (``None`` leaves the batch dimension unspecified).
      2. Build the ResNet model with ``training=False``. With ``--quantize``,
         rewrite the graph for quantized evaluation.
      3. Restore ``--checkpoint``, or initialize the variables when no
         checkpoint is given. Then convert the variables to constants, using
         the probability tensor as the single output node.
      4. Serialize the frozen GraphDef to ``--output_file``. The output is
         binary, or text with ``--write_text_graphdef``.

    Args:
        _: Remaining command-line arguments passed by ``tf.app.run``; unused.

    Raises:
        ValueError: If ``--output_file`` is empty.
    """
    # Validate arguments before doing any (potentially expensive) setup.
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')

    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
    hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        if FLAGS.input_format == 'NCHW':
            input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
        else:
            input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
        input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)

        network = resnet_v1_5.ResnetModel(
            FLAGS.model_name, FLAGS.num_classes, FLAGS.compute_format, FLAGS.input_format)
        # Only the probabilities are exported; the logits are not needed here.
        probs, _ = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            # Rewrites the default graph (the one built above) for quantized eval.
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                # Not an error (useful for smoke tests), but the result is an
                # untrained model, so make that visible.
                tf.logging.warning(
                    'No --checkpoint given; exporting freshly initialized '
                    '(untrained) weights.')
                sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def()
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, [probs.op.name])

        # Write out the frozen graph. os.path.dirname() returns '' for a bare
        # file name, so fall back to the current directory rather than
        # handing write_graph an empty logdir.
        output_dir = os.path.dirname(FLAGS.output_file) or '.'
        tf.io.write_graph(
            frozen_graph_def,
            output_dir,
            os.path.basename(FLAGS.output_file),
            as_text=FLAGS.write_text_graphdef)
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then invokes main(argv).
    tf.app.run(main=main)