134 lines
5.2 KiB
Python
Executable file
134 lines
5.2 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
r"""
|
|
To run inference with the model on the framework runtime, use the `run_inference_on_fw.py` script.
|
|
It runs inference locally on the data provided by the specified data loader and saves the resulting data into npz files.
|
|
Those files are stored in the directory specified by the `--output-dir` argument.
|
|
|
|
Example call:
|
|
|
|
```shell script
|
|
python ./triton/run_inference_on_fw.py \
|
|
--input-path /models/exported/model.onnx \
|
|
--input-type onnx \
|
|
--dataloader triton/dataloader.py \
|
|
--data-dir /data/imagenet \
|
|
--batch-size 32 \
|
|
--output-dir /results/dump_local \
|
|
--dump-labels
|
|
```
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# These variables are set before the remaining imports below; presumably so
# that TensorFlow, if it is pulled in by a loader/runner extension, starts with
# reduced log output and without deprecation warnings — TODO confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_DEPRECATION_WARNINGS"] = "0"
|
|
|
|
from tqdm import tqdm
|
|
|
|
# method from PEP-366 to support relative import in executed modules:
# when `__package__` is None, set it to the name of the directory containing
# this file so that the `from .deployment_toolkit ...` imports below have a
# package name to resolve against.
if __package__ is None:
    __package__ = Path(__file__).parent.name
|
|
|
|
from .deployment_toolkit.args import ArgParserGenerator
|
|
from .deployment_toolkit.core import DATALOADER_FN_NAME, BaseLoader, BaseRunner, Format, load_from_file
|
|
from .deployment_toolkit.dump import NpzWriter
|
|
from .deployment_toolkit.extensions import loaders, runners
|
|
|
|
# Logger used by this script; the logging level and format are configured in main().
LOGGER = logging.getLogger("run_inference_on_fw")
|
|
|
|
|
|
def _verify_and_format_dump(args, ids, x, y_pred, y_real):
    """Assemble the keyword payload that is written for a single batch.

    Args:
        args: Parsed command-line namespace; only ``dump_inputs`` and
            ``dump_labels`` are read.
        ids: Identifiers of the batch items; stored as ``{"ids": ids}``.
        x: Model inputs; included under ``"inputs"`` when ``args.dump_inputs`` is set.
        y_pred: Model outputs; always included under ``"outputs"``.
        y_real: Labels; included under ``"labels"`` when ``args.dump_labels`` is set.

    Returns:
        dict: ``"outputs"`` and ``"ids"`` entries, plus the optional
        ``"inputs"`` and ``"labels"`` entries.

    Raises:
        ValueError: If labels were requested but ``y_real`` is empty/falsy.
    """
    payload = {"outputs": y_pred, "ids": {"ids": ids}}
    payload.update({"inputs": x} if args.dump_inputs else {})

    if not args.dump_labels:
        return payload

    # Labels were explicitly requested, so their absence is a user error.
    if not y_real:
        raise ValueError(
            "Found empty label values. Please provide labels in dataloader_fn or do not use --dump-labels argument"
        )
    payload["labels"] = y_real
    return payload
|
|
|
|
|
|
def _parse_and_validate_args():
    """Build the command-line parser in two passes and return the parsed arguments.

    The first pass (``parse_known_args``) reads only the fixed options so that
    ``--dataloader``, ``--input-type`` and ``--input-path`` are known. Those values
    are then used to let ``ArgParserGenerator`` extend the same parser with the
    options of the dataloader function, the loader and the runner. The second
    pass (``parse_args``) parses the complete command line.

    Returns:
        argparse.Namespace: The fully parsed arguments.
    """
    # Only formats that have both a runner and a loader can be used here.
    common_formats = set(runners.supported_extensions).intersection(loaders.supported_extensions)

    parser = argparse.ArgumentParser(description="Dump local inference output of given model", allow_abbrev=False)
    parser.add_argument("--input-path", help="Path to input model", required=True)
    parser.add_argument("--input-type", help="Input model type", choices=common_formats, required=True)
    for flag, description in (
        ("--dataloader", "Path to python file containing dataloader."),
        ("--output-dir", "Path to dir where output files will be stored"),
    ):
        parser.add_argument(flag, help=description, required=True)
    for flag, description in (
        ("--dump-labels", "Dump labels to output dir"),
        ("--dump-inputs", "Dump inputs to output dir"),
    ):
        parser.add_argument(flag, help=description, action="store_true", default=False)
    parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)

    # First pass: options not registered yet are tolerated and ignored.
    preliminary = parser.parse_known_args()[0]

    dataloader_factory = load_from_file(preliminary.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
    ArgParserGenerator(dataloader_factory).update_argparser(parser)

    loader_cls: BaseLoader = loaders.get(preliminary.input_type)
    ArgParserGenerator(loader_cls, module_path=preliminary.input_path).update_argparser(parser)

    runner_cls: BaseRunner = runners.get(preliminary.input_type)
    ArgParserGenerator(runner_cls).update_argparser(parser)

    # Second pass: every option is registered now, so parse strictly.
    args = parser.parse_args()

    # No input type currently requires explicit I/O parameters; the list is
    # empty, so the check below never fires as written.
    types_requiring_io_params = []

    needs_io_params = args.input_type in types_requiring_io_params
    if needs_io_params and not all([args.inputs, args.outputs]):
        parser.error(f"For {args.input_type} input provide --inputs and --outputs parameters")

    return args
|
|
|
|
|
|
def main():
    """Run the model locally on every dataloader batch and dump the results.

    Parses the command line, configures logging, builds the loader and runner
    registered for ``--input-type``, loads the model from ``--input-path`` and
    then, for each ``(ids, x, y_real)`` batch yielded by the dataloader, runs the
    model and passes the formatted batch to an ``NpzWriter`` opened on
    ``--output-dir``.
    """
    args = _parse_and_validate_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    LOGGER.info("args:")
    for name, setting in vars(args).items():
        LOGGER.info(f" {name} = {setting}")

    loader_cls: BaseLoader = loaders.get(args.input_type)
    runner_cls: BaseRunner = runners.get(args.input_type)

    loader = ArgParserGenerator(loader_cls, module_path=args.input_path).from_args(args)
    runner = ArgParserGenerator(runner_cls).from_args(args)

    LOGGER.info(f"Loading {args.input_path}")
    model = loader.load(args.input_path)

    # The inference session is entered first and the writer second, matching
    # the order of the original combined ``with`` statement.
    with runner.init_inference(model=model) as infer:
        with NpzWriter(args.output_dir) as writer:
            dataloader_factory = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
            make_batches = ArgParserGenerator(dataloader_factory).from_args(args)
            LOGGER.info("Data loader initialized; Running inference")

            for ids, x, y_real in tqdm(make_batches(), unit="batch", mininterval=10):
                y_pred = infer(x)
                writer.write(**_verify_and_format_dump(args, ids=ids, x=x, y_pred=y_pred, y_real=y_real))

            LOGGER.info("Inference finished")
|
|
|
|
|
|
# Run the inference dump only when this file is executed as a script.
if __name__ == "__main__":
    main()
|