130 lines
4.6 KiB
Python
130 lines
4.6 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
import os
|
|
|
|
import tensorflow as tf
|
|
|
|
import horovod.tensorflow as hvd
|
|
from model import resnet
|
|
|
|
# Command-line flags controlling which model is built, how its input
# placeholder is shaped, whether it is quantized, and where the frozen graph
# is written. All of them are read by main().
tf.app.flags.DEFINE_string(
    'model_name', 'resnet50', 'The name of the architecture to save. The default name was being '
    'used to train the model')

# Used for both spatial dimensions of the input placeholder built in main().
tf.app.flags.DEFINE_integer(
    'image_size', 224,
    'The height and width of the exported model input placeholder.')

tf.app.flags.DEFINE_integer(
    'num_classes', 1001,
    'The number of classes to predict.')

tf.app.flags.DEFINE_integer(
    'batch_size', None,
    'Batch size for the exported model. Defaulted to "None" so batch size can '
    'be specified at model runtime.')

# Layout of the 'input' placeholder; main() accepts only 'NCHW' or 'NHWC'.
tf.app.flags.DEFINE_string('input_format', 'NCHW',
                           "The data format of the input placeholder: 'NCHW' or 'NHWC'.")

# Layout used internally by the model layers (forwarded to resnet.ResnetModel).
tf.app.flags.DEFINE_string('compute_format', 'NCHW',
                           'The dataformat used by the layers in the model')

# When empty, main() freezes freshly initialized (untrained) weights.
tf.app.flags.DEFINE_string('checkpoint', '',
                           'The trained model checkpoint.')

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

# The three quantization flags below are forwarded to
# tf.contrib.quantize.experimental_create_eval_graph when --quantize is set.
tf.app.flags.DEFINE_bool(
    'quantize', False, 'whether to use quantized graph or not.')

tf.app.flags.DEFINE_bool(
    'symmetric', False, 'Using symmetric quantization or not.')

tf.app.flags.DEFINE_bool(
    'use_qdq', False, 'Use quantize and dequantize op instead of fake quant op')

# Forwarded unchanged to ResnetModel.build_model(use_final_conv=...).
tf.app.flags.DEFINE_bool(
    'use_final_conv', False,
    'Whether to build the model with use_final_conv enabled '
    '(passed to ResnetModel.build_model).')

tf.app.flags.DEFINE_bool('write_text_graphdef', False,
                         'Whether to write a text version of graphdef.')

FLAGS = tf.app.flags.FLAGS
|
|
|
|
|
|
def main(_):
    """Build the selected ResNet for inference, load weights and write a frozen graph.

    All configuration comes from the command-line FLAGS. The exported graph
    has a float32 placeholder named ``input``. It is frozen with its
    variables converted to constants, pruned to what is needed to compute the
    model's probability output.

    Args:
      _: argv passed by tf.app.run (unused).

    Raises:
      ValueError: if ``--output_file`` is empty, ``--input_format`` is not
        ``'NCHW'``/``'NHWC'``, or ``--model_name`` is not a key of
        ``resnet.model_architectures``.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')

    # Previously any value other than 'NCHW' silently fell through to NHWC, so a
    # typo such as 'nchw' produced a graph with a wrongly shaped input.
    if FLAGS.input_format not in ('NCHW', 'NHWC'):
        raise ValueError("--input_format must be 'NCHW' or 'NHWC', got %r"
                         % (FLAGS.input_format,))

    try:
        model_config = resnet.model_architectures[FLAGS.model_name]
    except KeyError:
        raise ValueError('Unknown --model_name %r' % (FLAGS.model_name,))

    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
    hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    # batch_size may be None, leaving the batch dimension dynamic.
    if FLAGS.input_format == 'NCHW':
        input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
    else:
        input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]

    with tf.Graph().as_default() as graph:
        input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)

        network = resnet.ResnetModel(FLAGS.model_name,
                                     FLAGS.num_classes,
                                     model_config['layers'],
                                     model_config['widths'],
                                     model_config['expansions'],
                                     FLAGS.compute_format,
                                     FLAGS.input_format)
        probs, _ = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            # Rewrites the default graph for quantized evaluation. Done before
            # the Saver is built so that the Saver also covers any variables
            # the rewrite adds.
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                tf.logging.warning(
                    'No --checkpoint given; freezing randomly initialized weights.')
                sess.run(tf.global_variables_initializer())
            # Keep only the subgraph feeding the probability output and bake
            # the current variable values in as constants.
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), [probs.op.name])

        # Write out the frozen graph. A bare filename has an empty dirname, so
        # fall back to the current directory instead of passing '' as logdir.
        output_dir = os.path.dirname(FLAGS.output_file) or '.'
        tf.io.write_graph(
            frozen_graph_def,
            output_dir,
            os.path.basename(FLAGS.output_file),
            as_text=FLAGS.write_text_graphdef)
        tf.logging.info('Wrote frozen graph to %s', FLAGS.output_file)
|
|
|
|
|
|
if __name__ == '__main__':
    # Let tf.app parse the command-line flags, then hand control to main().
    tf.app.run(main=main)
|