#!/usr/bin/env python3

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
|
|
To infer the model deployed on Triton, you can use `run_inference_on_triton.py` script.
|
|
It sends a request with data obtained from pointed data loader and dumps received data into
|
|
[npz files](https://gitlab-master.nvidia.com/dl/JoC/bermuda-api/-/blob/develop/bermuda_api_toolset/docs/dump_files.md).
|
|
Those files are stored in directory pointed by `--output-dir` argument.
|
|
|
|
Currently, the client communicates with the Triton server asynchronously using GRPC protocol.
|
|
|
|
Example call:
|
|
|
|
```shell script
|
|
python ./triton/run_inference_on_triton.py \
|
|
--server-url localhost:8001 \
|
|
--model-name ResNet50 \
|
|
--model-version 1 \
|
|
--dump-labels \
|
|
--output-dir /results/dump_triton
|
|
```
|
|
"""

import argparse
import functools
import logging
import queue
import threading
import time
from pathlib import Path
from typing import Optional

from tqdm import tqdm

# pytype: disable=import-error
try:
    from tritonclient import utils as client_utils  # noqa: F401
    from tritonclient.grpc import (
        InferenceServerClient,
        InferInput,
        InferRequestedOutput,
    )
except ImportError:
    # Fall back to the legacy (pre-2.0) client packages; the dtype helpers used
    # below (client_utils.triton_to_np_dtype) live in tritonclientutils there.
    import tritonclientutils as client_utils  # noqa: F401
    from tritongrpcclient import (
        InferenceServerClient,
        InferInput,
        InferRequestedOutput,
    )
# pytype: enable=import-error

# method from PEP-366 to support relative import in executed modules
if __package__ is None:
    __package__ = Path(__file__).parent.name

from .deployment_toolkit.args import ArgParserGenerator
from .deployment_toolkit.core import DATALOADER_FN_NAME, load_from_file
from .deployment_toolkit.dump import NpzWriter

LOGGER = logging.getLogger("run_inference_on_triton")


class AsyncGRPCTritonRunner:
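    """Asynchronous gRPC runner for a model deployed on Triton.

    A background thread feeds batches from the dataloader to the server, keeping
    at most `max_unresponded_reqs` requests in flight; iterating over the runner
    yields `(ids, x, y_pred, y_real)` tuples as responses arrive.
    """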

    DEFAULT_MAX_RESP_WAIT_S = 120
    DEFAULT_MAX_UNRESP_REQS = 128
    DEFAULT_MAX_FINISH_WAIT_S = 900  # 15 min

    def __init__(
        self,
        server_url: str,
        model_name: str,
        model_version: str,
        *,
        dataloader,
        verbose=False,
        resp_wait_s: Optional[float] = None,
        max_unresponded_reqs: Optional[int] = None,
    ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._dataloader = dataloader
        self._verbose = verbose
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
        self._max_unresp_reqs = self.DEFAULT_MAX_UNRESP_REQS if max_unresponded_reqs is None else max_unresponded_reqs

        self._results = queue.Queue()
        self._processed_all = False
        self._errors = []
        self._num_waiting_for = 0
        self._sync = threading.Condition()
        self._req_thread = threading.Thread(target=self.req_loop, daemon=True)

    def __iter__(self):
        self._req_thread.start()
        timeout_s = 0.050  # check the processed_all and error flags every 50 ms
        while True:
            try:
                ids, x, y_pred, y_real = self._results.get(timeout=timeout_s)
                yield ids, x, y_pred, y_real
            except queue.Empty:
                shall_stop = self._processed_all or self._errors
                if shall_stop:
                    break

        LOGGER.debug("Waiting for request thread to stop")
        self._req_thread.join()
        if self._errors:
            error_msg = "\n".join(map(str, self._errors))
            raise RuntimeError(error_msg)

    def _on_result(self, ids, x, y_real, output_names, result, error):
        # Callback invoked by the client for each completed request.
        with self._sync:
            if error:
                self._errors.append(error)
            else:
                y_pred = {name: result.as_numpy(name) for name in output_names}
                self._results.put((ids, x, y_pred, y_real))
            self._num_waiting_for -= 1
            self._sync.notify_all()

    def req_loop(self):
        client = InferenceServerClient(self._server_url, verbose=self._verbose)
        self._errors = self._verify_triton_state(client)
        if self._errors:
            return

        LOGGER.debug(f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} are up and ready!")

        model_config = client.get_model_config(self._model_name, self._model_version)
        model_metadata = client.get_model_metadata(self._model_name, self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        inputs = {tm.name: tm for tm in model_metadata.inputs}
        outputs = {tm.name: tm for tm in model_metadata.outputs}
        output_names = list(outputs)
        outputs_req = [InferRequestedOutput(name) for name in outputs]

        self._num_waiting_for = 0

        for ids, x, y_real in self._dataloader:
            infer_inputs = []
            for name in inputs:
                data = x[name]
                infer_input = InferInput(name, data.shape, inputs[name].datatype)

                # Cast the batch to the numpy dtype matching the model's input dtype.
                target_np_dtype = client_utils.triton_to_np_dtype(inputs[name].datatype)
                data = data.astype(target_np_dtype)

                infer_input.set_data_from_numpy(data)
                infer_inputs.append(infer_input)

            with self._sync:
                # Backpressure: wait until the number of in-flight requests drops
                # below the limit; give up after the response wait timeout.
                def _check_can_send():
                    return self._num_waiting_for < self._max_unresp_reqs

                can_send = self._sync.wait_for(_check_can_send, timeout=self._response_wait_t)
                if not can_send:
                    error_msg = f"Runner could not send new requests for {self._response_wait_t}s"
                    self._errors.append(error_msg)
                    break

                callback = functools.partial(AsyncGRPCTritonRunner._on_result, self, ids, x, y_real, output_names)
                client.async_infer(
                    model_name=self._model_name,
                    model_version=self._model_version,
                    inputs=infer_inputs,
                    outputs=outputs_req,
                    callback=callback,
                )
                self._num_waiting_for += 1

        # Wait until all in-flight requests have been answered.
        with self._sync:

            def _all_processed():
                LOGGER.debug(f"waiting for {self._num_waiting_for} unprocessed jobs")
                return self._num_waiting_for == 0

            self._processed_all = self._sync.wait_for(_all_processed, self.DEFAULT_MAX_FINISH_WAIT_S)
            if not self._processed_all:
                error_msg = f"Runner timed out after {self.DEFAULT_MAX_FINISH_WAIT_S}s while waiting for remaining results from the server"
                self._errors.append(error_msg)
        LOGGER.debug("Finished request thread")

    def _verify_triton_state(self, triton_client):
        errors = []
        if not triton_client.is_server_live():
            errors.append(f"Triton server {self._server_url} is not live")
        elif not triton_client.is_server_ready():
            errors.append(f"Triton server {self._server_url} is not ready")
        elif not triton_client.is_model_ready(self._model_name, self._model_version):
            errors.append(f"Model {self._model_name}:{self._model_version} is not ready")
        return errors


def _parse_args():
    parser = argparse.ArgumentParser(description="Infer model on Triton server", allow_abbrev=False)
    parser.add_argument(
        "--server-url", type=str, default="localhost:8001", help="Inference server URL (default localhost:8001)"
    )
    parser.add_argument("--model-name", help="The name of the model used for inference.", required=True)
    parser.add_argument("--model-version", help="The version of the model used for inference.", required=True)
    parser.add_argument("--dataloader", help="Path to python file containing dataloader.", required=True)
    parser.add_argument("--dump-labels", help="Dump labels to output dir", action="store_true", default=False)
    parser.add_argument("--dump-inputs", help="Dump inputs to output dir", action="store_true", default=False)
    parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)
    parser.add_argument("--output-dir", required=True, help="Path to directory where outputs will be saved")
    parser.add_argument(
        "--response-wait-time", required=False, type=float, help="Maximal time to wait for response", default=120
    )
    parser.add_argument(
        "--max-unresponded-requests", required=False, type=int, help="Maximal number of unresponded requests", default=128
    )

    # First pass recovers the dataloader path; the dataloader factory then registers
    # its own arguments on the parser before the final parse.
    args, *_ = parser.parse_known_args()

    get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
    ArgParserGenerator(get_dataloader_fn).update_argparser(parser)
    args = parser.parse_args()

    return args


def main():
    args = _parse_args()

    log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
    log_level = logging.INFO if not args.verbose else logging.DEBUG
    logging.basicConfig(level=log_level, format=log_format)

    LOGGER.info("args:")
    for key, value in vars(args).items():
        LOGGER.info(f"    {key} = {value}")

    get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
    dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)

    runner = AsyncGRPCTritonRunner(
        args.server_url,
        args.model_name,
        args.model_version,
        dataloader=dataloader_fn(),
        verbose=False,
        resp_wait_s=args.response_wait_time,
        max_unresponded_reqs=args.max_unresponded_requests,
    )

    with NpzWriter(output_dir=args.output_dir) as writer:
        start = time.time()
        for ids, x, y_pred, y_real in tqdm(runner, unit="batch", mininterval=10):
            data = _verify_and_format_dump(args, ids, x, y_pred, y_real)
            writer.write(**data)
        stop = time.time()

    LOGGER.info(f"\nThe inference took {stop - start:0.3f}s")


def _verify_and_format_dump(args, ids, x, y_pred, y_real):
    data = {"outputs": y_pred, "ids": {"ids": ids}}
    if args.dump_inputs:
        data["inputs"] = x
    if args.dump_labels:
        if not y_real:
            raise ValueError(
                "Found empty label values. Please provide labels in dataloader_fn or do not use --dump-labels argument"
            )
        data["labels"] = y_real
    return data
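

# A minimal sketch of inspecting the dumped results (this assumes NpzWriter stores
# each dumped group as its own .npz archive, e.g. outputs.npz, under --output-dir;
# see the dump_files.md document linked in the module docstring for the exact layout):
#
#     import numpy as np
#
#     outputs = np.load("/results/dump_triton/outputs.npz")
#     print({name: array.shape for name, array in outputs.items()})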


if __name__ == "__main__":
    main()