DeepLearningExamples/TensorFlow/Segmentation/UNet_Industrial/export_saved_model.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ==============================================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""
Usage:
python export_saved_model.py \
--activation_fn='relu' \
--batch_size=16 \
--data_format='NCHW' \
--input_dtype="fp32" \
--export_dir="exported_models" \
--model_checkpoint_path="path/to/checkpoint/model.ckpt-2500" \
--unet_variant='tinyUNet' \
--use_xla \
--use_tf_amp
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import argparse
import pprint

# Silence TensorFlow's C++ logging; this must be set before `tensorflow` is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

from dllogger.logger import LOGGER

from model.unet import UNet_v1
from model.blocks.activation_blck import authorized_activation_fn

from utils.cmdline_helper import _add_bool_argument


def get_export_flags():
    parser = argparse.ArgumentParser(description="JoC-UNet_v1-TF-ExportFlags")

    parser.add_argument('--export_dir', default=None, required=True, type=str, help='The export directory.')
    parser.add_argument('--model_checkpoint_path', default=None, required=True, help='Checkpoint path.')

    parser.add_argument(
        '--data_format',
        choices=['NHWC', 'NCHW'],
        type=str,
        default="NCHW",
        required=False,
        help="""Which tensor format is used for computation inside the model"""
    )

    parser.add_argument(
        '--input_dtype',
        choices=['fp32', 'fp16'],
        type=str,
        default="fp32",
        required=False,
        help="""TensorFlow dtype of the input tensor"""
    )

    parser.add_argument(
        '--unet_variant',
        default="tinyUNet",
        choices=UNet_v1.authorized_models_variants,
        type=str,
        required=False,
        help="""Which model size is used. This parameter directly controls the size and the number of parameters"""
    )

    parser.add_argument(
        '--activation_fn',
        choices=authorized_activation_fn,
        type=str,
        default="relu",
        required=False,
        help="""Which activation function is used after the convolution layers"""
    )

    _add_bool_argument(
        parser=parser,
        name="use_tf_amp",
        default=False,
        required=False,
        help="Enable Automatic Mixed Precision computation to maximise performance."
    )

    _add_bool_argument(
        parser=parser,
        name="use_xla",
        default=False,
        required=False,
        help="Enable TensorFlow XLA to maximise performance."
    )

    parser.add_argument('--batch_size', default=16, type=int, help='Evaluation batch size.')

    FLAGS, unknown_args = parser.parse_known_args()

    if len(unknown_args) > 0:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")

    return FLAGS


def export_model(RUNNING_CONFIG):
    if RUNNING_CONFIG.use_tf_amp:
        # Enable the TF 1.x automatic mixed precision graph rewrite.
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    model = UNet_v1(
        model_name="UNet_v1",
        input_format="NHWC",
        compute_format=RUNNING_CONFIG.data_format,
        n_output_channels=1,
        unet_variant=RUNNING_CONFIG.unet_variant,
        weight_init_method="he_normal",
        activation_fn=RUNNING_CONFIG.activation_fn
    )

    config_proto = tf.ConfigProto()

    config_proto.allow_soft_placement = True
    config_proto.log_device_placement = False
    config_proto.gpu_options.allow_growth = True

    if RUNNING_CONFIG.use_xla:  # Only working on single GPU
        LOGGER.log("XLA is activated - Experimental Feature")
        config_proto.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    config_proto.gpu_options.force_gpu_compatible = True  # Force pinned memory

    run_config = tf.estimator.RunConfig(
        model_dir=None,
        tf_random_seed=None,
        save_summary_steps=1e9,  # disabled
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        session_config=config_proto,
        keep_checkpoint_max=None,
        keep_checkpoint_every_n_hours=1e9,  # disabled
        log_step_count_steps=1e9,
        train_distribute=None,
        device_fn=None,
        protocol=None,
        eval_distribute=None,
        experimental_distribute=None
    )

    estimator = tf.estimator.Estimator(
        model_fn=model,
        model_dir=RUNNING_CONFIG.model_checkpoint_path,
        config=run_config,
        params={'debug_verbosity': 0}
    )

    LOGGER.log('[*] Exporting the model ...')

    input_type = tf.float32 if RUNNING_CONFIG.input_dtype == "fp32" else tf.float16

    def get_serving_input_receiver_fn():

        # Serving input is NHWC: [batch, height, width, channels].
        input_shape = [RUNNING_CONFIG.batch_size, 512, 512, 1]

        def serving_input_receiver_fn():
            features = tf.placeholder(dtype=input_type, shape=input_shape, name='input_tensor')

            return tf.estimator.export.TensorServingInputReceiver(features=features, receiver_tensors=features)

        return serving_input_receiver_fn

    export_path = estimator.export_saved_model(
        export_dir_base=RUNNING_CONFIG.export_dir,
        serving_input_receiver_fn=get_serving_input_receiver_fn(),
        checkpoint_path=RUNNING_CONFIG.model_checkpoint_path
    )

    LOGGER.log('[*] Done! path: `%s`' % export_path.decode())
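
# A minimal consumption sketch (assumptions: the TF 1.x `tf.contrib.predictor`
# module is available, and the signature input key is `input` -- verify the
# actual key with `saved_model_cli show --dir <export_path> --all`):
#
#     from tensorflow.contrib import predictor
#     predict_fn = predictor.from_saved_model(export_path.decode())
#     outputs = predict_fn({"input": images})  # images: [batch_size, 512, 512, 1]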


if __name__ == '__main__':

    tf.logging.set_verbosity(tf.logging.ERROR)
    tf.disable_eager_execution()

    flags = get_export_flags()

    # A TF 1.x checkpoint is a set of companion files; verify that the `.index`
    # and `.meta` files exist for the given checkpoint prefix.
    for endpattern in [".index", ".meta"]:
        file_to_check = flags.model_checkpoint_path + endpattern
        if not os.path.isfile(file_to_check):
            raise FileNotFoundError("The checkpoint file `%s` does not exist" % file_to_check)

    print(" ========================= Export Flags =========================\n")
    pprint.pprint(dict(flags._get_kwargs()))
    print("\n %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")

    export_model(flags)