222 lines
6.6 KiB
Python
222 lines
6.6 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# ==============================================================================
|
|
#
|
|
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# ==============================================================================
|
|
|
|
"""
|
|
Usage:
|
|
|
|
python export_saved_model.py \
|
|
--activation_fn='relu' \
|
|
--batch_size=16 \
|
|
--data_format='NCHW' \
|
|
--input_dtype="fp32" \
|
|
--export_dir="exported_models" \
|
|
--model_checkpoint_path="path/to/checkpoint/model.ckpt-2500" \
|
|
--unet_variant='tinyUNet' \
|
|
--use_xla \
|
|
--use_tf_amp
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import argparse
|
|
import pprint
|
|
|
|
# Set before `import tensorflow` below so the variable is already in the
# environment when TensorFlow is loaded.
# NOTE(review): level "3" presumably silences TensorFlow's C++ log output —
# confirm against the TensorFlow documentation.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
|
import tensorflow as tf
|
|
|
|
from dllogger.logger import LOGGER
|
|
|
|
from model.unet import UNet_v1
|
|
from model.blocks.activation_blck import authorized_activation_fn
|
|
|
|
from utils.cmdline_helper import _add_bool_argument
|
|
|
|
|
|
def get_export_flags():
    """Build the export command-line parser and parse the process arguments.

    Registers the export flags (`--export_dir`, `--model_checkpoint_path`,
    `--data_format`, `--input_dtype`, `--unet_variant`, `--activation_fn`,
    the `use_tf_amp` / `use_xla` boolean flags and `--batch_size`) and parses
    them with `parse_known_args`.

    Returns:
        The `argparse.Namespace` holding the recognised flags.

    Raises:
        ValueError: if any argument is not recognised by the parser. Each
            offending argument is printed before the exception is raised.
    """
    arg_parser = argparse.ArgumentParser(description="JoC-UNet_v1-TF-ExportFlags")

    # Mandatory locations: where to write the export and which checkpoint to load.
    arg_parser.add_argument(
        '--export_dir',
        type=str,
        required=True,
        default=None,
        help='The export directory.',
    )
    arg_parser.add_argument(
        '--model_checkpoint_path',
        required=True,
        default=None,
        help='Checkpoint path.',
    )

    # Optional string flags limited to a fixed set of values, registered in
    # this order: (flag, allowed values, default, help text).
    choice_flags = (
        (
            '--data_format',
            ['NHWC', 'NCHW'],
            "NCHW",
            """Which Tensor format is used for computation inside the mode""",
        ),
        (
            '--input_dtype',
            ['fp32', 'fp16'],
            "fp32",
            """Tensorflow dtype of the input tensor""",
        ),
        (
            '--unet_variant',
            UNet_v1.authorized_models_variants,
            "tinyUNet",
            """Which model size is used. This parameter control directly the size and the number of parameters""",
        ),
        (
            '--activation_fn',
            authorized_activation_fn,
            "relu",
            """Which activation function is used after the convolution layers""",
        ),
    )
    for flag, allowed_values, fallback, description in choice_flags:
        arg_parser.add_argument(
            flag,
            choices=allowed_values,
            type=str,
            default=fallback,
            required=False,
            help=description,
        )

    # Boolean switches go through the project helper; both default to False.
    bool_flags = (
        ("use_tf_amp", "Enable Automatic Mixed Precision Computation to maximise performance."),
        ("use_xla", "Enable Tensorflow XLA to maximise performance."),
    )
    for switch_name, description in bool_flags:
        _add_bool_argument(
            parser=arg_parser,
            name=switch_name,
            default=False,
            required=False,
            help=description,
        )

    arg_parser.add_argument('--batch_size', default=16, type=int, help='Evaluation batch size.')

    parsed_flags, leftovers = arg_parser.parse_known_args()

    # Happy path: everything on the command line was recognised.
    if not leftovers:
        return parsed_flags

    # Report every unrecognised argument before failing.
    for leftover in leftovers:
        print("ERROR: Unknown command line arg: %s" % leftover)

    raise ValueError("Invalid command line arg(s)")
|
|
|
|
|
|
def export_model(RUNNING_CONFIG):
    """Export a UNet_v1 checkpoint as a TensorFlow SavedModel.

    Builds the `UNet_v1` model function, wraps it in a
    `tf.estimator.Estimator` and calls `export_saved_model` with a serving
    input receiver that exposes a single placeholder named `input_tensor`.
    The resulting export path is logged through `LOGGER`.

    Args:
        RUNNING_CONFIG: Object exposing the export flags as attributes. The
            attributes read here are `use_tf_amp`, `use_xla`, `data_format`,
            `unet_variant`, `activation_fn`, `model_checkpoint_path`,
            `input_dtype`, `batch_size` and `export_dir`.
    """

    if RUNNING_CONFIG.use_tf_amp:
        # NOTE(review): judging by its name, this variable turns on
        # TensorFlow's automatic mixed precision graph rewrite — confirm.
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    model = UNet_v1(
        model_name="UNet_v1",
        input_format="NHWC",
        compute_format=RUNNING_CONFIG.data_format,
        n_output_channels=1,
        unet_variant=RUNNING_CONFIG.unet_variant,
        weight_init_method="he_normal",
        activation_fn=RUNNING_CONFIG.activation_fn
    )

    config_proto = tf.ConfigProto()

    config_proto.allow_soft_placement = True
    config_proto.log_device_placement = False

    config_proto.gpu_options.allow_growth = True

    if RUNNING_CONFIG.use_xla:  # Only working on single GPU
        LOGGER.log("XLA is activated - Experimental Feature")
        config_proto.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    config_proto.gpu_options.force_gpu_compatible = True  # Force pinned memory

    run_config = tf.estimator.RunConfig(
        model_dir=None,
        tf_random_seed=None,
        save_summary_steps=1e9,  # disabled
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        session_config=config_proto,
        keep_checkpoint_max=None,
        keep_checkpoint_every_n_hours=1e9,  # disabled
        log_step_count_steps=1e9,
        train_distribute=None,
        device_fn=None,
        protocol=None,
        eval_distribute=None,
        experimental_distribute=None
    )

    estimator = tf.estimator.Estimator(
        model_fn=model,
        model_dir=RUNNING_CONFIG.model_checkpoint_path,
        config=run_config,
        params={'debug_verbosity': 0}
    )

    LOGGER.log('[*] Exporting the model ...')

    # Compare against the flag value: the flag is a non-empty string, so
    # testing its truthiness would select float32 even when "fp16" is asked for.
    input_type = tf.float16 if RUNNING_CONFIG.input_dtype == "fp16" else tf.float32

    def get_serving_input_receiver_fn():
        """Return a serving input receiver fn bound to the configured batch size."""

        # Fixed-size input: batch_size x 512 x 512 x 1, matching the "NHWC"
        # input_format passed to the model above.
        input_shape = [RUNNING_CONFIG.batch_size, 512, 512, 1]

        def serving_input_receiver_fn():
            features = tf.placeholder(dtype=input_type, shape=input_shape, name='input_tensor')

            # The same placeholder is both what the client feeds and what the model receives.
            return tf.estimator.export.TensorServingInputReceiver(features=features, receiver_tensors=features)

        return serving_input_receiver_fn

    export_path = estimator.export_saved_model(
        export_dir_base=RUNNING_CONFIG.export_dir,
        serving_input_receiver_fn=get_serving_input_receiver_fn(),
        checkpoint_path=RUNNING_CONFIG.model_checkpoint_path
    )

    LOGGER.log('[*] Done! path: `%s`' % export_path.decode())
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the flags, check the checkpoint files exist,
    # print the effective configuration, then run the export.

    tf.logging.set_verbosity(tf.logging.ERROR)
    tf.disable_eager_execution()

    flags = get_export_flags()

    # Fail early, with a clear message, when a checkpoint companion file is missing.
    for endpattern in [".index", ".meta"]:
        file_to_check = flags.model_checkpoint_path + endpattern
        if not os.path.isfile(file_to_check):
            raise FileNotFoundError("The checkpoint file `%s` does not exist" % file_to_check)

    print(" ========================= Export Flags =========================\n")
    # vars() is the documented way to view a Namespace as a dict; it replaces
    # the private Namespace._get_kwargs() helper. pprint sorts the keys.
    pprint.pprint(vars(flags))
    print("\n %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")

    export_model(flags)
|