DeepLearningExamples/TensorFlow/Segmentation/VNet/utils/tf_export.py
Przemek Strzelczyk b4aef9945b Adding VNet/TF
2019-12-02 15:57:25 +01:00

287 lines
9.8 KiB
Python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import glob
import inspect
import os
import shutil
import subprocess
from argparse import Namespace
from typing import List, Callable

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
# Module-level side effect: set the TF_CPP_MIN_LOG_LEVEL environment variable to "3"
# (presumably to silence TensorFlow's native logging while exporting).
# NOTE(review): this runs after `import tensorflow`; confirm the setting is still honoured.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def _compress(src_path: str, dst_path: str):
    """
    Pack the contents of a directory into a zip archive

    :param src_path: (str) Directory whose contents are archived
    :param dst_path: (str) Archive path without the '.zip' extension
    """
    print('[*] Compressing...')

    # make_archive appends the '.zip' extension to the base name on its own
    shutil.make_archive(base_name=dst_path, format='zip', root_dir=src_path)

    done_message = '[*] Compressed the contents in: {}.zip'.format(dst_path)
    print(done_message)
def _print_input(func: Callable):
    """
    Decorator printing the function name and the arguments of every call

    :param func: (Callable) Decorated function
    :return: Wrapped callable that keeps the name, docstring and signature of `func`
    """

    # functools.wraps copies __name__/__qualname__/__doc__ onto the wrapper and sets
    # __wrapped__, so the decorated exporters stay introspectable (help(), inspect.signature).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Prints the qualified name and the bound arguments of the call, then forwards
        # the positional (`args`) and keyword (`kwargs`) arguments to `func` unchanged
        # and returns its result.
        tf.logging.set_verbosity(tf.logging.ERROR)

        # bind() maps the call onto the parameter names of `func`; it raises TypeError
        # for a call that does not match the signature, before `func` is executed.
        func_args = inspect.signature(func).bind(*args, **kwargs).arguments
        func_args_str = ''.join('\t{} = {!r}\n'.format(*item) for item in func_args.items())

        print('[*] Running \'{}\' with arguments:'.format(func.__qualname__))
        # Drop the trailing newline, print() adds its own
        print(func_args_str[:-1])

        return func(*args, **kwargs)

    return wrapper
def _parse_placeholder_types(values: str):
    """
    Convert a comma separated string of placeholder types into integers.

    :param values: (str) Comma separated placeholder types
    :return: (List[int] or int) List of parsed types, or the bare integer when
        the string holds a single entry
    """
    parsed_types = list(map(int, values.split(",")))
    if len(parsed_types) > 1:
        return parsed_types
    # A single entry is unwrapped instead of being returned as a one-element list
    return parsed_types[0]
def _optimize_checkpoint_for_inference(graph_path: str,
                                       input_names: List[str],
                                       output_names: List[str]):
    """
    Replace a text-format graph file with its inference-optimized version

    The graph is run through `optimize_for_inference` (per the original author this
    removes Horovod and training related information). The untouched file is kept
    next to the new one with an extra '.old' suffix.

    :param graph_path: (str) Path to the graph.pbtxt file
    :param input_names: (List[str]) Input node names
    :param output_names: (List[str]) Output node names
    """
    print('[*] Optimizing graph for inference ...')

    # Parse the text-format protobuf into a GraphDef message
    training_graph_def = graph_pb2.GraphDef()
    with gfile.Open(graph_path, "rb") as graph_file:
        raw_graph = graph_file.read()
    text_format.Merge(raw_graph.decode("utf-8"), training_graph_def)

    # Datatype enum of float32, handed to the optimizer as the placeholder type
    placeholder_type_enum = _parse_placeholder_types(str(dtypes.float32.as_datatype_enum))
    inference_graph_def = optimize_for_inference_lib.optimize_for_inference(
        training_graph_def,
        input_names,
        output_names,
        placeholder_type_enum,
        False)

    # Keep the original graph around before writing the optimized one in its place
    backup_path = graph_path + '.old'
    print('[*] Saving original graph in: {}'.format(backup_path))
    shutil.move(graph_path, backup_path)

    print('[*] Writing down optimized graph ...')
    graph_dir, graph_name = os.path.dirname(graph_path), os.path.basename(graph_path)
    graph_io.write_graph(inference_graph_def, graph_dir, graph_name)
@_print_input
def to_savedmodel(input_shape: str,
                  model_fn: Callable,
                  checkpoint_dir: str,
                  output_dir: str,
                  input_names: List[str],
                  output_names: List[str],
                  use_amp: bool,
                  use_xla: bool,
                  compress: bool,
                  params: Namespace):
    """
    Export checkpoint to Tensorflow savedModel

    The graph.pbtxt file inside `checkpoint_dir` is optimized for inference in place
    (the original is kept as graph.pbtxt.old) before the export takes place.

    :param input_shape: (str) Input shape to the model in format [batch, height, width, channels].
        It is passed unchanged as the `shape` of the serving placeholder
    :param model_fn: (Callable) Estimator's model_fn
    :param checkpoint_dir: (str) Directory where checkpoints are stored
    :param output_dir: (str) Output directory for storage of the generated savedModel
    :param input_names: (List[str]) Input node names
    :param output_names: (List[str]) Output node names
    :param use_amp: (bool) Enable TF-AMP. Sets TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE=1 in the
        environment and makes the serving placeholder float16 instead of float32
    :param use_xla: (bool) Enable XLA
    :param compress: (bool) Compress output into `output_dir`/saved_model.zip
    :param params: (Namespace) Namespace to be passed to model_fn
    :raises ValueError: If `checkpoint_dir` contains no '*.index' checkpoint file
    """
    assert os.path.exists(checkpoint_dir), 'Path not found: {}'.format(checkpoint_dir)
    assert input_shape is not None, 'Input shape must be provided'

    # Locate the checkpoint BEFORE rewriting graph.pbtxt, so that a directory without
    # checkpoints is rejected without having been modified.
    index_files = glob.glob(os.path.join(checkpoint_dir, '*.index'))
    if not index_files:
        raise ValueError('Could not find checkpoint in directory: {}'.format(checkpoint_dir))
    # glob yields files in arbitrary order; when several checkpoints are present export
    # the most recently written one instead of whichever happens to be listed first.
    latest_index_file = max(index_files, key=os.path.getmtime)
    # Checkpoint prefix = index file name without the '.index' extension
    ckpt_path = os.path.splitext(latest_index_file)[0]

    _optimize_checkpoint_for_inference(os.path.join(checkpoint_dir, 'graph.pbtxt'), input_names, output_names)

    config_proto = tf.compat.v1.ConfigProto()
    config_proto.allow_soft_placement = True
    config_proto.log_device_placement = False
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.force_gpu_compatible = True

    if use_amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    if use_xla:
        config_proto.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1

    run_config = tf.estimator.RunConfig(
        model_dir=None,
        tf_random_seed=None,
        save_summary_steps=1e9,  # disabled
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        session_config=config_proto,
        keep_checkpoint_max=None,
        keep_checkpoint_every_n_hours=1e9,  # disabled
        log_step_count_steps=1e9,
        train_distribute=None,
        device_fn=None,
        protocol=None,
        eval_distribute=None,
        experimental_distribute=None
    )

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=ckpt_path,
        config=run_config,
        params=params
    )

    print('[*] Exporting the model ...')

    input_type = tf.float16 if use_amp else tf.float32

    def serving_input_receiver_fn():
        # A single placeholder acts both as the model features and as the receiver tensor
        features = tf.placeholder(dtype=input_type, shape=input_shape, name='input_tensor')
        return tf.estimator.export.TensorServingInputReceiver(features=features, receiver_tensors=features)

    export_path = estimator.export_saved_model(
        export_dir_base=output_dir,
        serving_input_receiver_fn=serving_input_receiver_fn,
        checkpoint_path=ckpt_path
    )

    # export_saved_model returns the export directory as bytes
    print('[*] Done! path: `%s`' % export_path.decode())

    if compress:
        _compress(export_path.decode(), os.path.join(output_dir, 'saved_model'))
@_print_input
def to_tf_trt(savedmodel_dir: str,
              output_dir: str,
              precision: str,
              feed_dict_fn: Callable,
              num_runs: int,
              output_tensor_names: List[str],
              compress: bool):
    """
    Export Tensorflow savedModel to TF-TRT

    NOTE: an already existing `output_dir` is deleted before the converted model is saved.

    :param savedmodel_dir: (str) Input directory containing a Tensorflow savedModel
    :param output_dir: (str) Output directory for storage of the generated TF-TRT exported model
    :param precision: (str) Desired precision of the network (FP32, FP16 or INT8)
    :param feed_dict_fn: (Callable) Input tensors for INT8 calibration. Model specific.
    :param num_runs: (int) Number of calibration runs.
    :param output_tensor_names: (List) Name of the output tensor for graph conversion. Model specific.
    :param compress: (bool) Compress output into tftrt_saved_model.zip in the working directory
    :raises FileNotFoundError: If `savedmodel_dir` is None or does not exist
    """
    if savedmodel_dir is None or not os.path.exists(savedmodel_dir):
        raise FileNotFoundError('savedmodel_dir not found: {}'.format(savedmodel_dir))

    if os.path.exists(output_dir):
        print('[*] Output dir \'{}\' is not empty. Cleaning up ...'.format(output_dir))
        shutil.rmtree(output_dir)

    print('[*] Converting model...')

    converter = trt.TrtGraphConverter(input_saved_model_dir=savedmodel_dir,
                                      precision_mode=precision)
    converter.convert()

    # Calibration is only run for INT8
    if precision == 'INT8':
        print('[*] Running INT8 calibration ...')
        converter.calibrate(fetch_names=output_tensor_names, num_runs=num_runs, feed_dict_fn=feed_dict_fn)

    converter.save(output_dir)

    print('[*] Done! TF-TRT saved_model stored in: `%s`' % output_dir)

    if compress:
        # _compress(src_path, dst_path): the directory that was just written is the source and
        # 'tftrt_saved_model' the archive name, mirroring the call made by to_onnx.
        _compress(output_dir, 'tftrt_saved_model')
@_print_input
def to_onnx(input_dir: str, output_dir: str, compress: bool):
    """
    Convert Tensorflow savedModel to ONNX with tf2onnx

    The conversion runs `python -m tf2onnx.convert` as a subprocess with its output discarded.

    :param input_dir: (str) Input directory with a Tensorflow savedModel
    :param output_dir: (str) Output directory where to store the ONNX version of the model
    :param compress: (bool) Compress output into onnx_model.zip in the working directory
    :raises RuntimeError: If the tf2onnx subprocess exits with a non-zero return code
    """
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(output_dir, exist_ok=True)

    file_name = os.path.join(output_dir, 'model.onnx')
    print('[*] Converting model...')

    # subprocess.DEVNULL discards the converter output without opening (and leaking)
    # a file handle on os.devnull.
    ret = subprocess.call(['python', '-m', 'tf2onnx.convert',
                           '--saved-model', input_dir,
                           '--output', file_name],
                          stdout=subprocess.DEVNULL,
                          stderr=subprocess.STDOUT)

    # A negative return code means the process was killed by a signal; that is a
    # failure as well, so every non-zero value is rejected.
    if ret != 0:
        raise RuntimeError('tf2onnx.convert has failed with error: {}'.format(ret))

    print('[*] Done! ONNX file stored in: %s' % file_name)

    if compress:
        _compress(output_dir, 'onnx_model')