[Convnets/TF] TF-TRT support

Lukasz Pierscieniewski 2021-11-02 06:53:59 -07:00 committed by Krzysztof Kudrynski
parent a0c9442f63
commit dcd3bbac09
15 changed files with 223 additions and 103 deletions

View file

@@ -22,7 +22,7 @@ import os
 import tensorflow as tf
-import horovod.tensorflow as hvd
+from utils import hvd_wrapper as hvd
 from model import resnet

 tf.app.flags.DEFINE_string(
@@ -75,8 +75,6 @@ FLAGS = tf.app.flags.FLAGS
 def main(_):
-    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
-    hvd.init()

     if not FLAGS.output_file:

View file

@@ -0,0 +1,134 @@
+import argparse
+import os
+import pathlib
+import time
+import tempfile
+
+import tensorflow as tf
+import numpy as np
+from tensorflow.python.compiler.tensorrt import trt_convert as trt
+
+import dllogger
+
+from runtime import runner_utils
+from runtime import runner
+from model.resnet import model_architectures
+from utils import data_utils
+from utils import hvd_wrapper as hvd
+
+OUTPUT_SAVED_MODEL_PATH = tempfile.mkdtemp(prefix="tftrt-converted")
+LOG_FREQUENCY = 100
+
+
+def argument_parser() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+
+    exclusive_args = parser.add_mutually_exclusive_group()
+    exclusive_args.add_argument("--model", type=str, default=None, help="Saved model location to use for inference")
+    exclusive_args.add_argument("--architecture", type=str, choices=model_architectures.keys())
+
+    parser.add_argument("--log-path", type=str, default="./log.json", help="Path to log file")
+    parser.add_argument("--tf-trt", action="store_true", default=False, help="Use TF-TRT for inference")
+    parser.add_argument("--amp", action="store_true", default=False, help="Use AMP for inference")
+    parser.add_argument("--data-dir", type=str, required=False,
+                        default=None, help="Location of validation data")
+    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for inference")
+
+    return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+    hvd.init()
+
+    dllogger.init(backends=[
+        dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, filename=args.log_path),
+        dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
+    ])
+    dllogger.log(data=vars(args), step='PARAMETER')
+
+    if args.model is None:
+        saved_model_to_load = tempfile.mkdtemp(prefix="tftrt-savedmodel")
+        r = runner.Runner(n_classes=1001, architecture=args.architecture, use_tf_amp=args.amp,
+                          model_dir=saved_model_to_load)
+        r.train("batch", 1, 1, args.batch_size, is_benchmark=True)
+        r.evaluate("batch", 1, args.batch_size, export_dir=saved_model_to_load,
+                   is_benchmark=True)
+
+        saved_model_to_load = r.exported_path.decode("utf-8")
+    else:
+        saved_model_to_load = args.model
+
+    output_tensor_name = "y_preds_ref:0" if not args.tf_trt else "ArgMax:0"
+    batch_size = args.batch_size
+
+    if args.tf_trt:
+        converter = trt.TrtGraphConverter(input_saved_model_dir=str(saved_model_to_load),
+                                          precision_mode="FP16" if args.amp else "FP32")
+        converter.convert()
+        converter.save(OUTPUT_SAVED_MODEL_PATH)
+        saved_model_to_load = OUTPUT_SAVED_MODEL_PATH
+    elif args.amp:
+        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
+
+    if args.data_dir is not None:
+        filenames, _, num_steps, _, _ = runner_utils.parse_tfrecords_dataset(
+            data_dir=str(args.data_dir),
+            mode="validation",
+            iter_unit="epoch",
+            num_iter=1,
+            global_batch_size=batch_size,
+        )
+
+        dataset = data_utils.get_tfrecords_input_fn(filenames=filenames,
+                                                    batch_size=batch_size,
+                                                    height=224,
+                                                    width=224,
+                                                    training=False,
+                                                    distort_color=False,
+                                                    num_threads=1,
+                                                    deterministic=True)
+        iterator = dataset.make_initializable_iterator()
+        next_item = iterator.get_next()
+    else:
+        num_steps = 60000 / batch_size
+
+    with tf.Session() as sess:
+        if args.data_dir is not None:
+            sess.run(iterator.initializer)
+        tf.saved_model.loader.load(sess,
+                                   [tf.saved_model.tag_constants.SERVING],
+                                   str(saved_model_to_load))
+
+        try:
+            start_time = time.time()
+            last_time = start_time
+            image_processed = 0
+            image_correct = 0
+
+            for samples_processed in range(int(num_steps)):
+                if args.data_dir is not None:
+                    next_batch_image, next_batch_target = sess.run(next_item)
+                else:
+                    if samples_processed == 0:
+                        next_batch_image = np.random.normal(size=(batch_size, 224, 224, 3))
+                        next_batch_target = np.random.randint(0, 1000, size=(batch_size,))
+                output = sess.run([output_tensor_name], feed_dict={"input_tensor:0": next_batch_image})
+
+                image_processed += args.batch_size
+                image_correct += np.sum(output == next_batch_target)
+
+                if samples_processed % LOG_FREQUENCY == 0 and samples_processed != 0:
+                    current_time = time.time()
+                    current_throughput = LOG_FREQUENCY * batch_size / (current_time - last_time)
+                    dllogger.log(step=(0, samples_processed), data={"throughput": current_throughput})
+                    last_time = current_time
+        except tf.errors.OutOfRangeError:
+            pass
+        finally:
+            dllogger.log(step=tuple(), data={"throughput": image_processed / (last_time - start_time),
+                                             "accuracy": image_correct / image_processed})
+
+
+if __name__ == "__main__":
+    main(argument_parser())
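For context, the new script above drives TF-TRT through the TF1 `TrtGraphConverter` API. A minimal standalone sketch of just the conversion step, assuming an existing SavedModel (both paths are illustrative, not from the commit):

```python
# Hedged sketch of the TF-TRT conversion flow used above (TF1 API).
from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter = trt.TrtGraphConverter(
    input_saved_model_dir="/models/resnet50_savedmodel",  # hypothetical input path
    precision_mode="FP16")            # the script picks FP16 only when --amp is set
converter.convert()                   # rewrites supported subgraphs into TRTEngineOp nodes
converter.save("/models/resnet50_tftrt")                  # hypothetical output path
```

The converted SavedModel is then loaded and run exactly like the original one, which is why the script only swaps `saved_model_to_load` after conversion.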

View file

@@ -22,10 +22,9 @@ warnings.simplefilter("ignore")
 import tensorflow as tf
-import horovod.tensorflow as hvd
+from utils import hvd_wrapper as hvd
 import dllogger
-from utils import hvd_utils
 from runtime import Runner
 from model.resnet import model_architectures
@@ -36,7 +35,7 @@ if __name__ == "__main__":
     tf.logging.set_verbosity(tf.logging.ERROR)

     FLAGS = parse_cmdline(model_architectures.keys())
-    hvd.init()
+    hvd.init(True)

     if hvd.rank() == 0:
         log_path = os.path.join(FLAGS.results_dir, FLAGS.log_filename)
@@ -100,11 +99,10 @@ if __name__ == "__main__":
     if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:
-        if FLAGS.mode == 'inference_benchmark' and hvd_utils.is_using_hvd():
+        if FLAGS.mode == 'inference_benchmark' and hvd.size() > 1:
             raise NotImplementedError("Only single GPU inference is implemented.")
-        elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
+        elif hvd.rank() == 0:
             runner.evaluate(iter_unit=FLAGS.iter_unit if FLAGS.mode != "train_and_evaluate" else "epoch",
                             num_iter=FLAGS.num_iter if FLAGS.mode != "train_and_evaluate" else 1,
                             warmup_steps=FLAGS.warmup_steps,
@@ -124,10 +122,10 @@ if __name__ == "__main__":
         if not os.path.isfile(FLAGS.to_predict):
             raise ValueError("Only prediction on single images is supported!")

-        if hvd_utils.is_using_hvd():
+        if hvd.size() > 1:
             raise NotImplementedError("Only single GPU inference is implemented.")

-        elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
+        else:
             runner.predict(FLAGS.to_predict,
                            quantize=FLAGS.quantize,
                            symmetric=FLAGS.symmetric,

View file

@@ -64,10 +64,10 @@ def conv2d_block(
             trainable=is_training,
             dtype=tf.float32)

         net = tf.nn.conv2d(inputs,
-                            group_filter,
-                            strides=strides,
-                            padding='SAME',
-                            data_format=data_format)
+                           group_filter,
+                           strides=strides,
+                           padding='SAME',
+                           data_format=data_format)

         if use_batch_norm:
             net = layers.batch_norm(
                 net,

View file

@@ -19,15 +19,13 @@ from __future__ import print_function
 import tensorflow as tf
-import horovod.tensorflow as hvd
+from utils import hvd_wrapper as hvd
 import dllogger

 from model import layers
 from model import blocks

 from utils import var_storage
-from utils import hvd_utils
 from utils.data_utils import normalized_inputs
 from utils.learning_rate import learning_rate_scheduler
@@ -337,8 +335,8 @@ class ResnetModel(object):
         if params["apply_loss_scaling"]:
             optimizer = FixedLossScalerOptimizer(optimizer, scale=params["loss_scale"])

-        if hvd_utils.is_using_hvd():
-            optimizer = hvd.DistributedOptimizer(optimizer)
+        if hvd.size() > 1:
+            optimizer = hvd.hvd_global_object.DistributedOptimizer(optimizer)

         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         if mode != tf.estimator.ModeKeys.TRAIN:

View file

@@ -276,7 +276,6 @@ The `utils/` directory contains the following utility modules:
 - `cmdline_helper.py`: helper module for command line processing
 - `data_utils.py`: module defining input data pipelines
 - `dali_utils.py`: helper module for DALI
-- `hvd_utils.py`: helper module for Horovod
 - `image_processing.py`: image processing and data augmentation functions
 - `learning_rate.py`: definition of used learning rate schedule
 - `optimizers.py`: definition of used custom optimizers
@@ -447,7 +446,11 @@ To run inference on a single example with a checkpoint and a model script, use:
 `python main.py --mode predict --model_dir <path to model> --to_predict <path to image> --results_dir <path to results>`

-The optional `--xla` and `--amp` flags control XLA and AMP during inference.
+The optional `--xla` and `--amp` flags control XLA and AMP during inference. To run inference using TF-TRT, please use the following command:
+
+`python inference.py --model <path to model> --tf-trt --batch-size <inference_batch_size> --data-dir <path to data>`
+
+The optional `--amp` flag controls AMP during inference.
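For example, a hypothetical invocation with both TF-TRT and AMP enabled (all paths below are illustrative):

`python inference.py --model /checkpoints/resnet50_savedmodel --tf-trt --amp --batch-size 64 --data-dir /data/imagenet-tfrecords`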
## Performance

View file

@@ -283,7 +283,6 @@ The `utils/` directory contains the following utility modules:
 - `cmdline_helper.py`: helper module for command line processing
 - `data_utils.py`: module defining input data pipelines
 - `dali_utils.py`: helper module for DALI
-- `hvd_utils.py`: helper module for Horovod
 - `image_processing.py`: image processing and data augmentation functions
 - `learning_rate.py`: definition of used learning rate schedule
 - `optimizers.py`: definition of used custom optimizers

View file

@@ -21,13 +21,11 @@ import warnings
 import tensorflow as tf
 import numpy as np
-import horovod.tensorflow as hvd

 from model import resnet

 from utils import hooks
 from utils import data_utils
-from utils import hvd_utils
+from utils import hvd_wrapper as hvd

 from runtime import runner_utils
@@ -142,8 +140,8 @@ class Runner(object):
             gpu_id=gpu_id)

         run_config_additional = tf.contrib.training.HParams(
-            model_dir=model_dir, #if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
-            log_dir=log_dir if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
+            model_dir=model_dir,
+            log_dir=log_dir if hvd.rank() == 0 else None,
             data_dir=data_dir,
             data_idx_dir=data_idx_dir,
             num_preprocessing_threads=num_preprocessing_threads)
@@ -196,11 +194,7 @@ class Runner(object):
     @staticmethod
     def _get_global_batch_size(worker_batch_size):
-        if hvd_utils.is_using_hvd():
-            return worker_batch_size * hvd.size()
-        else:
-            return worker_batch_size
+        return worker_batch_size * hvd.size()

     @staticmethod
     def _get_session_config(mode, use_xla, use_dali, use_cpu, gpu_memory_fraction, gpu_id=0):
@@ -225,7 +219,7 @@ class Runner(object):
             config.gpu_options.visible_device_list = str(gpu_id)
             config.gpu_options.force_gpu_compatible = True  # Force pinned memory

-        if hvd_utils.is_using_hvd():
+        if hvd.size() > 1:
             config.gpu_options.visible_device_list = str(hvd.local_rank())
             config.gpu_options.force_gpu_compatible = True  # Force pinned memory
@@ -248,10 +242,7 @@ class Runner(object):
             mode)

         if seed is not None:
-            if hvd_utils.is_using_hvd():
-                tf_random_seed = 2 * (seed + hvd.rank())
-            else:
-                tf_random_seed = 2 * seed
+            tf_random_seed = 2 * (seed + hvd.rank())
         else:
             tf_random_seed = None
@@ -277,11 +268,8 @@ class Runner(object):
             experimental_distribute=None)

         if mode == 'train':
-            if hvd_utils.is_using_hvd():
-                config = config.replace(save_checkpoints_steps=1000 if hvd.rank() == 0 else None,
-                                        keep_checkpoint_every_n_hours=3)
-            else:
-                config = config.replace(save_checkpoints_steps=1000, keep_checkpoint_every_n_hours=3)
+            config = config.replace(save_checkpoints_steps=1000 if hvd.rank() == 0 else None,
+                                    keep_checkpoint_every_n_hours=3)

         return config
@@ -343,7 +331,7 @@ class Runner(object):
         else:
             use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

-        num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
+        num_gpus = hvd.size()
         global_batch_size = batch_size * num_gpus

         if self.run_hparams.data_dir is not None:
@@ -402,8 +390,8 @@ class Runner(object):
         )

         training_hooks.append(self.training_logging_hook)

-        if hvd_utils.is_using_hvd():
-            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
+        if hvd.size() > 1:
+            bcast_hook = hvd.hvd_global_object.BroadcastGlobalVariablesHook(0)
             training_hooks.append(bcast_hook)

         training_hooks.append(hooks.PrefillStagingAreasHook())
@@ -527,7 +515,7 @@ class Runner(object):
         if self.run_hparams.data_dir is None and not is_benchmark:
             raise ValueError('`data_dir` must be specified for evaluation!')

-        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
+        if hvd.rank() != 0:
             raise RuntimeError('Multi-GPU inference is not supported')

         estimator_params = {'quantize': quantize,
@@ -620,18 +608,25 @@ class Runner(object):
             )

             eval_throughput = self.eval_logging_hook.mean_throughput.value()
-            eval_latencies = np.array(self.eval_logging_hook.latencies) * 1000
-            eval_latencies_q = np.quantile(eval_latencies, q=[0.9, 0.95, 0.99])
-            eval_latencies_mean = np.mean(eval_latencies)
+
+            if len(self.eval_logging_hook.latencies) > 0:
+                eval_latencies = np.array(self.eval_logging_hook.latencies) * 1000
+                eval_latencies_q = np.quantile(eval_latencies, q=[0.9, 0.95, 0.99])
+                eval_latencies_mean = np.mean(eval_latencies)
+                additional_metrics = {
+                    'eval_latency_avg': eval_latencies_mean,
+                    'eval_latency_p90': eval_latencies_q[0],
+                    'eval_latency_p95': eval_latencies_q[1],
+                    'eval_latency_p99': eval_latencies_q[2],
+                }
+            else:
+                additional_metrics = {}

             dllogger.log(data={
                 'top1_accuracy': float(eval_results['top1_accuracy']),
                 'top5_accuracy': float(eval_results['top5_accuracy']),
                 'eval_throughput': eval_throughput,
-                'eval_latency_avg': eval_latencies_mean,
-                'eval_latency_p90': eval_latencies_q[0],
-                'eval_latency_p95': eval_latencies_q[1],
-                'eval_latency_p99': eval_latencies_q[2],
+                **additional_metrics
             },
                 step=tuple())
@@ -644,7 +639,7 @@ class Runner(object):
                     data_format=self.run_hparams.input_format,
                     dtype=self.run_hparams.dtype)

-                image_classifier.export_savedmodel(export_dir, input_receiver_fn)
+                self.exported_path = image_classifier.export_savedmodel(export_dir, input_receiver_fn)

         except KeyboardInterrupt:
             print("Keyboard interrupt")

View file

@@ -278,7 +278,6 @@ The `utils/` directory contains the following utility modules:
 - `cmdline_helper.py`: helper module for command line processing
 - `data_utils.py`: module defining input data pipelines
 - `dali_utils.py`: helper module for DALI
-- `hvd_utils.py`: helper module for Horovod
 - `image_processing.py`: image processing and data augmentation functions
 - `learning_rate.py`: definition of used learning rate schedule
 - `optimizers.py`: definition of used custom optimizers

View file

@@ -23,8 +23,8 @@ def get_model(
     use_dali: bool = False,
     gpu_memory_fraction=0.7,
 ):
-    import horovod.tensorflow as hvd
     from runtime import Runner
+    from utils import hvd_wrapper as hvd

     hvd.init()

View file

@@ -18,10 +18,9 @@
 import sys

 import tensorflow as tf
-import horovod.tensorflow as hvd

 from utils import image_processing
-from utils import hvd_utils
+from utils import hvd_wrapper as hvd

 from nvidia import dali
 import nvidia.dali.plugin.tf as dali_tf

View file

@@ -18,11 +18,10 @@
 import sys

 import tensorflow as tf
-import horovod.tensorflow as hvd

 from utils import image_processing
-from utils import hvd_utils
 from utils import dali_utils
+from utils import hvd_wrapper as hvd

 __all__ = ["get_synth_input_fn", "normalized_inputs"]
@@ -82,16 +81,13 @@ def get_tfrecords_input_fn(filenames, batch_size, height, width, training, disto
     shuffle_buffer_size = 4096

     if deterministic:
-        if hvd_utils.is_using_hvd():
-            seed = 13 * (1 + hvd.rank())
-        else:
-            seed = 13
+        seed = 13 * hvd.rank()
     else:
         seed = None

     ds = tf.data.Dataset.from_tensor_slices(filenames)

-    if hvd_utils.is_using_hvd() and training:
+    if hvd.size() > 1 and training:
         ds = ds.shard(hvd.size(), hvd.rank())

     ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=10, block_length=8)
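As context for the sharding change above, `ds.shard(hvd.size(), hvd.rank())` gives each worker a disjoint slice of the input files. A toy illustration with synthetic filenames and plain `tf.data` (the size/rank values are hypothetical):

```python
# Each worker keeps every num_shards-th element, offset by its index.
import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(
    [f"train-{i:05d}.tfrecord" for i in range(8)])
worker = files.shard(num_shards=4, index=1)  # e.g. hvd.size()=4, hvd.rank()=1

with tf.Session() as sess:
    next_file = worker.make_one_shot_iterator().get_next()
    try:
        while True:
            print(sess.run(next_file))  # prints train-00001 and train-00005
    except tf.errors.OutOfRangeError:
        pass
```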

View file

@@ -21,8 +21,7 @@ import tensorflow as tf
 import dllogger
 import signal

-import horovod.tensorflow as hvd
-from utils.hvd_utils import is_using_hvd
+from utils import hvd_wrapper as hvd

 __all__ = ['TrainingLoggingHook', 'TrainingPartitionHook']
@@ -125,17 +124,17 @@ class TrainingPartitionHook(tf.estimator.SessionRunHook):
         signal.signal(signal.SIGTERM, self._signal_handler)

     def begin(self):
-        if is_using_hvd():
+        if hvd.size() > 1:
             with tf.device("/cpu:0"):
                 self.input_op = tf.placeholder(tf.int32, shape=())
-                self.allreduce_op = hvd.allreduce(self.input_op, op=hvd.Sum,
-                                                  name="signal_handler_all_reduce")
+                self.allreduce_op = hvd.hvd_global_object.allreduce(
+                    self.input_op, op=hvd.hvd_global_object.Sum, name="signal_handler_all_reduce")

     def before_run(self, run_context):
         fetches = [tf.train.get_global_step()]
         feed_dict = None

-        if is_using_hvd() and (self.global_step % self.sync_freq) == 0:
+        if hvd.size() > 1 and (self.global_step % self.sync_freq) == 0:
             fetches += [self.allreduce_op]
             feed_dict = {self.input_op: int(self.signal_recieved)}
@@ -144,7 +143,7 @@ class TrainingPartitionHook(tf.estimator.SessionRunHook):
     def after_run(self, run_context, run_values):
         self.global_step = run_values.results[0] + 1

-        if is_using_hvd() and len(run_values.results) == 2:
+        if hvd.size() > 1 and len(run_values.results) == 2:
             if run_values.results[1] > 0:
                 run_context.request_stop()
             elif self.signal_recieved:
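For intuition, `TrainingPartitionHook` propagates a SIGTERM across workers by all-reducing a 0/1 flag every `sync_freq` steps: each worker contributes 1 if it caught the signal, so a positive sum tells every rank to stop at the same step. A simplified standalone sketch of that pattern, assuming plain Horovod rather than the wrapper:

```python
# Simplified sketch of the all-reduce signal-propagation pattern (Horovod + TF1 assumed).
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

flag = tf.placeholder(tf.int32, shape=())    # 1 if this worker caught SIGTERM, else 0
summed = hvd.allreduce(flag, op=hvd.Sum,     # sum of flags across all workers
                       name="signal_sketch")

with tf.Session() as sess:
    caught_signal = 0                        # a real handler would flip this on SIGTERM
    if sess.run(summed, feed_dict={flag: caught_signal}) > 0:
        print("a worker was signalled; every rank requests a stop together")
```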

View file

@@ -1,30 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-import horovod.tensorflow as hvd
-
-
-def is_using_hvd():
-    return hvd.size() > 1
-    rank_env = ['HOROVOD_RANK', 'OMPI_COMM_WORLD_RANK', 'PMI_RANK']
-    size_env = ['HOROVOD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'PMI_SIZE']
-
-    for r_var, s_var in zip(rank_env, size_env):
-        if r_var in os.environ and s_var in os.environ:
-            return int(s_var) > 1
-
-    return False

View file

@@ -0,0 +1,32 @@
+hvd_global_object = None
+
+
+def init(use_horovod: bool = False):
+    global hvd_global_object
+    if use_horovod:
+        import horovod.tensorflow as hvd
+        hvd.init()
+        hvd_global_object = hvd
+    else:
+        class _DummyWrapper:
+            def rank(self): return 0
+            def size(self): return 1
+            def local_rank(self): return 0
+            def local_size(self): return 1
+        hvd_global_object = _DummyWrapper()
+
+
+def size():
+    global hvd_global_object
+    return hvd_global_object.size()
+
+
+def rank():
+    global hvd_global_object
+    return hvd_global_object.rank()
+
+
+def local_rank():
+    global hvd_global_object
+    return hvd_global_object.local_rank()
+
+
+def local_size():
+    global hvd_global_object
+    return hvd_global_object.local_size()
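A short usage sketch of the new wrapper (the call site below is illustrative; only `init`, `rank`, `size`, `local_rank`, `local_size`, and `hvd_global_object` come from the module above):

```python
# Illustrative call site for utils/hvd_wrapper.py.
from utils import hvd_wrapper as hvd

hvd.init(use_horovod=True)   # with the default use_horovod=False, a dummy object
                             # reports rank 0 / size 1, so single-GPU code paths
                             # never need Horovod installed or initialized

if hvd.size() > 1:
    # Horovod-only APIs are reached through the wrapped module, as the rest of
    # this commit does with DistributedOptimizer and allreduce.
    print("distributed run, rank", hvd.rank(), "of", hvd.size())
```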