DeepLearningExamples/TensorFlow/Classification/ConvNets/inference.py
2021-11-02 06:53:59 -07:00

134 lines
5.4 KiB
Python

import argparse
import os
import pathlib
import time
import tempfile
import tensorflow as tf
import numpy as np
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import dllogger
from runtime import runner_utils
from runtime import runner
from model.resnet import model_architectures
from utils import data_utils
from utils import hvd_wrapper as hvd
# Scratch directory that receives the TF-TRT-converted SavedModel.
OUTPUT_SAVED_MODEL_PATH = tempfile.mkdtemp(prefix="tftrt-converted")
# Emit a throughput log entry every LOG_FREQUENCY processed batches.
LOG_FREQUENCY = 100
def argument_parser() -> argparse.Namespace:
    """Parse command-line options for the inference script.

    Returns:
        argparse.Namespace: parsed options (model, architecture, log_path,
        tf_trt, amp, data_dir, batch_size).
    """
    parser = argparse.ArgumentParser()
    # Exactly one model source must be supplied: either a pre-exported
    # SavedModel path, or an architecture name used to train and export a
    # fresh one.  `required=True` makes argparse reject an invocation with
    # neither flag, instead of letting both stay None and crashing later
    # inside runner.Runner(architecture=None) with an obscure error.
    exclusive_args = parser.add_mutually_exclusive_group(required=True)
    exclusive_args.add_argument("--model", type=str, default=None,
                                help="Saved model location to use for inference")
    exclusive_args.add_argument("--architecture", type=str,
                                choices=model_architectures.keys())
    parser.add_argument("--log-path", type=str, default="./log.json",
                        help="Path to log file")
    parser.add_argument("--tf-trt", action="store_true", default=False,
                        help="Use TF-TRT for inference")
    parser.add_argument("--amp", action="store_true", default=False,
                        help="Use AMP for inference")
    parser.add_argument("--data-dir", type=str, required=False,
                        default=None, help="Localization of validation data")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Batch size for inference")
    return parser.parse_args()
def main(args: argparse.Namespace):
    """Run (optionally TF-TRT-converted) image-classification inference.

    Loads a SavedModel — either the one given via --model, or one freshly
    trained/exported from --architecture — then runs a TF1 session over
    validation TFRecords (if --data-dir is set) or a single random batch,
    logging throughput and accuracy through dllogger.
    """
    hvd.init()
    dllogger.init(backends=[
        dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, filename=args.log_path),
        dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
    ])
    dllogger.log(data=vars(args), step='PARAMETER')

    if args.model is None:
        # No pre-exported model given: train one step as a benchmark run and
        # export it, then load the resulting SavedModel path.
        saved_model_to_load = tempfile.mkdtemp(prefix="tftrt-savedmodel")
        r = runner.Runner(n_classes=1001, architecture=args.architecture, use_tf_amp=args.amp,
                          model_dir=saved_model_to_load)
        r.train("batch", 1, 1, args.batch_size, is_benchmark=True)
        r.evaluate("batch", 1, args.batch_size, export_dir=saved_model_to_load,
                   is_benchmark=True)
        # exported_path is bytes — decode to str for the loader below.
        saved_model_to_load = r.exported_path.decode("utf-8")
    else:
        saved_model_to_load = args.model

    # TF-TRT conversion renames the output node; pick the matching tensor.
    output_tensor_name = "y_preds_ref:0" if not args.tf_trt else "ArgMax:0"
    batch_size = args.batch_size

    if args.tf_trt:
        # Convert the SavedModel with TF-TRT and load the converted copy.
        converter = trt.TrtGraphConverter(input_saved_model_dir=str(saved_model_to_load),
                                          precision_mode="FP16" if args.amp else "FP32")
        converter.convert()
        converter.save(OUTPUT_SAVED_MODEL_PATH)
        saved_model_to_load = OUTPUT_SAVED_MODEL_PATH
    elif args.amp:
        # AMP without TF-TRT: enable TF's automatic mixed-precision rewrite.
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    if args.data_dir is not None:
        # Real validation data: one epoch over the validation TFRecords.
        filenames, _, num_steps, _, _ = runner_utils.parse_tfrecords_dataset(
            data_dir=str(args.data_dir),
            mode="validation",
            iter_unit="epoch",
            num_iter=1,
            global_batch_size=batch_size,
        )
        dataset = data_utils.get_tfrecords_input_fn(filenames=filenames,
                                                    batch_size=batch_size,
                                                    height=224,
                                                    width=224,
                                                    training=False,
                                                    distort_color=False,
                                                    num_threads=1,
                                                    deterministic=True)
        iterator = dataset.make_initializable_iterator()
        next_item = iterator.get_next()
    else:
        # Synthetic mode: iterate enough steps to cover 60000 samples.
        # NOTE(review): this is a float; it is truncated via int() below.
        num_steps=60000 / batch_size

    with tf.Session() as sess:
        if args.data_dir is not None:
            sess.run(iterator.initializer)
        tf.saved_model.loader.load(sess,
                                   [tf.saved_model.tag_constants.SERVING],
                                   str(saved_model_to_load))
        try:
            start_time = time.time()
            last_time = start_time
            image_processed = 0
            image_correct = 0
            for samples_processed in range(int(num_steps)):
                if args.data_dir is not None:
                    next_batch_image, next_batch_target = sess.run(next_item)
                else:
                    # Synthetic mode reuses one random batch for every step.
                    if samples_processed == 0:
                        next_batch_image = np.random.normal(size=(batch_size, 224, 224, 3))
                        next_batch_target = np.random.randint(0, 1000, size=(batch_size,))
                # NOTE(review): sess.run([...]) returns a one-element LIST, so the
                # `output == next_batch_target` comparison below broadcasts the list
                # against the target array — verify the accuracy count is correct.
                output = sess.run([output_tensor_name], feed_dict={"input_tensor:0": next_batch_image})
                image_processed += args.batch_size
                image_correct += np.sum(output == next_batch_target)
                # Periodic throughput log (skipping step 0, which has no window).
                if samples_processed % LOG_FREQUENCY == 0 and samples_processed != 0:
                    current_time = time.time()
                    current_throughput = LOG_FREQUENCY * batch_size / (current_time - last_time)
                    dllogger.log(step=(0, samples_processed), data={"throughput": current_throughput})
                    last_time = current_time
        except tf.errors.OutOfRangeError:
            # Dataset exhausted before num_steps — treat as a normal finish.
            pass
        finally:
            # NOTE(review): if fewer than LOG_FREQUENCY steps ran, last_time still
            # equals start_time, so this divides by zero; image_processed can also
            # be 0 if the loop never ran.  Consider using time.time() - start_time
            # and guarding the divisions — TODO confirm intended behavior.
            dllogger.log(step=tuple(), data={"throughput": image_processed / (last_time - start_time),
                                             "accuracy": image_correct / image_processed})
# Script entry point: parse CLI arguments and run inference.
if __name__ == "__main__":
    main(argument_parser())