[ResNet50/TF1][Triton] Bermuda 0.6.11 update

Piotr Marcinkiewicz 2021-07-01 12:06:38 -07:00 committed by Krzysztof Kudrynski
parent c481324031
commit 96b70f5d16
26 changed files with 352 additions and 85 deletions

View file

@@ -196,7 +196,7 @@ processing, and training of the model.
Generate the configuration from your model repository.
```
python3 triton/config_model_on_trion.py \
python3 triton/config_model_on_triton.py \
--model-repository ${MODEL_REPOSITORY_PATH} \
--model-path ${SHARED_DIR}/model \
--model-format ${FORMAT} \

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +16,8 @@
r"""
Using the `calculate_metrics.py` script, you can obtain model accuracy/error metrics using a defined `MetricsCalculator` class.
See [documentation](https://gitlab-master.nvidia.com/dl/JoC/bermuda-api/-/blob/develop/bermuda_api_toolset/docs/metrics.md) on preparation of this class.
Data provided to `MetricsCalculator` are obtained from [npz dump files](https://gitlab-master.nvidia.com/dl/JoC/bermuda-api/-/blob/develop/bermuda_api_toolset/docs/dump_files.md)
Data provided to `MetricsCalculator` are obtained from npz dump files
stored in the directory pointed to by the `--dump-dir` argument.
The above files are prepared by the `run_inference_on_fw.py` and `run_inference_on_triton.py` scripts.
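For orientation, a minimal sketch of consuming such npz dumps with numpy (the per-batch file layout and key names here are assumptions, not the documented contract):

```python
# Sketch only: iterate over npz dump files produced by the inference scripts.
from pathlib import Path

import numpy as np

def iter_dump_batches(dump_dir: str, subdir: str = "outputs"):
    """Yield {tensor_name: array} dicts, one per dumped npz file."""
    for npz_path in sorted(Path(dump_dir).glob(f"{subdir}/*.npz")):  # assumed layout
        with np.load(npz_path) as data:
            yield {name: data[name] for name in data.files}
```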

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,11 @@
# limitations under the License.
r"""
To deploy model in Triton, you can use `deploy_model.py` script.
To configure the model on Triton, you can use the `config_model_on_triton.py` script.
This will prepare the layout of the Model Repository, including the Model Configuration.
```shell script
python ./triton/deploy_model.py \
python ./triton/config_model_on_triton.py \
--model-repository /model_repository \
--model-path /models/exported/model.onnx \
--model-format onnx \
@@ -28,7 +28,7 @@ python ./triton/deploy_model.py \
--max-batch-size 32 \
--precision fp16 \
--backend-accelerator trt \
--load-model \
--load-model explicit \
--timeout 120 \
--verbose
```
@@ -36,14 +36,16 @@ python ./triton/deploy_model.py \
If the Triton server for which we prepare the model repository is running in **explicit model control mode**,
use the `--load-model` argument to send a load_model request to the Triton Inference Server.
If the server is listening on a non-default address or port, use the `--server-url` argument to point to the server control endpoint.
If it is required to use HTTP protocol to communcate with Triton server use `--http` argument.
If it is required to use HTTP protocol to communicate with Triton server use `--http` argument.
To improve inference throughput, you can use
[dynamic batching](https://github.com/triton-inference-server/server/blob/master/docs/model_configuration.md#dynamic-batcher)
for your model by providing the `--preferred-batch-sizes` and `--max-queue-delay-us` parameters.
For models that don't support batching, set `--max-batch-size` to 0.
By default, Triton will [automatically obtain inputs and outputs definitions](https://github.com/triton-inference-server/server/blob/master/docs/model_configuration.md#auto-generated-model-configuration),
but for TorchScript models script uses file with I/O specs. This file is automatically generated
but for TorchScript and TF GraphDef models the script uses a file with I/O specs. This file is automatically generated
when the model is converted to ScriptModule (either traced or scripted).
If you need to pass a non-default path to the I/O spec file, use the `--io-spec` CLI argument.
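For illustration, a dynamic-batching variant of the call above might look as follows (a sketch built only from the flags named in this description; the value syntax for `--preferred-batch-sizes` is an assumption):

```shell script
python ./triton/config_model_on_triton.py \
    --model-repository /model_repository \
    --model-path /models/exported/model.onnx \
    --model-format onnx \
    --max-batch-size 32 \
    --preferred-batch-sizes 8 16 32 \
    --max-queue-delay-us 100 \
    --load-model explicit \
    --verbose
```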
@@ -67,13 +69,14 @@ I/O spec file is yaml file with below structure:
import argparse
import logging
import time
from service_maker import Accelerator, Format, Precision
from service_maker.args import str2bool
from service_maker.log import dump_arguments, set_logger
from service_maker.triton import ModelConfig, TritonClient, TritonModelStore
from model_navigator import Accelerator, Format, Precision
from model_navigator.args import str2bool
from model_navigator.log import set_logger, log_dict
from model_navigator.triton import ModelConfig, TritonClient, TritonModelStore
LOGGER = logging.getLogger("deploy_model")
LOGGER = logging.getLogger("config_model")
def _available_enum_values(my_enum):
@@ -85,7 +88,7 @@ def main():
description="Create Triton model repository and model configuration", allow_abbrev=False
)
parser.add_argument("--model-repository", required=True, help="Path to Triton model repository.")
parser.add_argument("--model-path", required=True, help="Path to model to deploy")
parser.add_argument("--model-path", required=True, help="Path to model to configure")
# TODO: automation
parser.add_argument(
@@ -161,7 +164,7 @@ def main():
args = parser.parse_args()
set_logger(verbose=args.verbose)
dump_arguments(args)
log_dict("args", vars(args))
config = ModelConfig.create(
model_path=args.model_path,
@@ -184,8 +187,14 @@ def main():
if args.load_model != "none":
client = TritonClient(server_url=args.server_url, verbose=args.verbose)
client.wait_for_server_ready(timeout=args.timeout)
if args.load_model == "explicit":
client.load_model(model_name=args.model_name)
if args.load_model == "poll":
time.sleep(15)
client.wait_for_model(model_name=args.model_name, model_version=args.model_version, timeout_s=args.timeout)

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
r"""
The `convert_model.py` script allows you to convert between model formats, with additional model optimizations
for faster inference.
It converts model from results of [`get_model`](https://gitlab-master.nvidia.com/dl/JoC/bermuda-api/-/blob/develop/bermuda_api_toolset/docs/model.md) function.
It converts a model from the results of the `get_model` function.
Currently supported input and output formats are:

View file

@@ -1 +1 @@
0.4.6-46-g5bc739c
0.6.11

View file

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@@ -1,7 +1,21 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from typing import Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
from .core import GET_ARGPARSER_FN_NAME, load_from_file
@@ -39,7 +53,7 @@ def add_args_for_fn_signature(parser, fn) -> argparse.ArgumentParser:
if parameter.annotation == bool:
argument_kwargs["type"] = str2bool
argument_kwargs["choices"] = [0, 1]
elif type(parameter.annotation) == type(Union):
elif isinstance(parameter.annotation, type(Optional[Any])):
types = [type_ for type_ in parameter.annotation.__args__ if not isinstance(None, type_)]
if len(types) != 1:
raise RuntimeError(
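The `Optional` handling above can be sketched in isolation; this standalone snippet uses `typing.get_args`, which is equivalent to reading `__args__` directly as the script does:

```python
# Standalone illustration of unwrapping Optional[X] annotations.
import typing
from typing import Optional

def unwrap_optional(annotation):
    """Return X for Optional[X]; reject Unions with more than one real type."""
    inner = [t for t in typing.get_args(annotation) if t is not type(None)]
    if len(inner) != 1:
        raise RuntimeError(f"Unsupported annotation: {annotation}")
    return inner[0]

assert unwrap_optional(Optional[int]) is int
```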

View file

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict, Optional, Union

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Iterable, Optional

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from pathlib import Path
@@ -10,7 +24,7 @@ try:
import pycuda.autoinit
import pycuda.driver as cuda
except (ImportError, Exception) as e:
logging.getLogger(__name__).debug(f"Problems with importing pycuda package; {e}")
logging.getLogger(__name__).warning(f"Problems with importing pycuda package; {e}")
# pytype: enable=import-error
import tensorrt as trt # pytype: disable=import-error
@@ -199,4 +213,4 @@ if "pycuda.driver" in sys.modules:
runners.register_extension(Format.TRT.value, TensorRTRunner)
savers.register_extension(Format.TRT.value, TensorRTSaver)
else:
LOGGER.debug("Do not register TensorRT extension due problems with importing pycuda.driver package.")
LOGGER.warning("Do not register TensorRT extension due problems with importing pycuda.driver package.")

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict, Iterable, Optional, Tuple, Union
@@ -322,6 +336,13 @@ class TFKerasLoader(BaseLoader):
tf.keras.backend.clear_session()
tf.keras.backend.set_learning_phase(False)
def _add_suffix_as_quickfix_for_tf24_func_refactor(spec):
if not spec.name.endswith(":0"):
spec = spec._replace(name=spec.name + ":0")
return spec
inputs_dict = {name: _add_suffix_as_quickfix_for_tf24_func_refactor(spec) for name, spec in inputs_dict.items()}
return Model(graph_def, precision, inputs_dict, outputs_dict)
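As a standalone illustration of the quickfix above (with `TensorSpec` approximated by a namedtuple, which is an assumption about its shape):

```python
# TF1 graphs address tensors as "<op_name>:0"; TF 2.4's function refactor
# can drop that suffix, so the loader re-appends it.
from collections import namedtuple

TensorSpec = namedtuple("TensorSpec", ["name", "dtype", "shape"])

spec = TensorSpec(name="input_1", dtype="float32", shape=(-1, 224, 224, 3))
if not spec.name.endswith(":0"):
    spec = spec._replace(name=spec.name + ":0")
assert spec.name == "input_1:0"
```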

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Iterable
# pytype: disable=import-error

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from ..core import BaseConverter, Format, Model, Precision, ShapeSpec

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
from typing import Callable, Dict, List

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import importlib
import logging
@@ -25,6 +39,9 @@ class Parameter(Enum):
def __lt__(self, other: "Parameter") -> bool:
return self.value < other.value
def __str__(self):
return self.value
class Accelerator(Parameter):
AMP = "amp"

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict, Iterable
@@ -131,3 +145,4 @@ class NpzWriter:
def __exit__(self, exc_type, exc_val, exc_tb):
self.flush()

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import os

View file

@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import re
from typing import Dict, List

View file

@@ -1,15 +1,32 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import List, Optional
def warmup(
model_name: str,
batch_sizes: List[int],
triton_gpu_engine_count: int = 1,
triton_instances: int = 1,
profiling_data: str = "random",
input_shapes: Optional[List[str]] = None,
server_url: str = "localhost",
measurement_window: int = 10000,
shared_memory: bool = False
):
print("\n")
print(f"==== Warmup start ====")
@@ -17,30 +34,33 @@ def warmup(
input_shapes = " ".join(map(lambda shape: f" --shape {shape}", input_shapes)) if input_shapes else ""
bs = set()
bs.add(min(batch_sizes))
bs.add(max(batch_sizes))
measurement_window = 6 * measurement_window
for batch_size in bs:
exec_args = f"""-max-threads {triton_instances} \
-m {model_name} \
-x 1 \
-c {triton_instances} \
-t {triton_instances} \
-p {measurement_window} \
-v \
-i http \
-u {server_url}:8000 \
-b {batch_size} \
--input-data {profiling_data} {input_shapes}
"""
max_batch_size = max(batch_sizes)
max_total_requests = 2 * max_batch_size * triton_instances * triton_gpu_engine_count
max_concurrency = min(256, max_total_requests)
batch_size = max(1, max_total_requests // 256)
result = os.system(f"perf_client {exec_args}")
if result != 0:
print(f"Failed running performance tests. Perf client failed with exit code {result}")
exit(1)
step = max(1, max_concurrency // 2)
min_concurrency = step
exec_args = f"""-m {model_name} \
-x 1 \
-p {measurement_window} \
-v \
-i http \
-u {server_url}:8000 \
-b {batch_size} \
--concurrency-range {min_concurrency}:{max_concurrency}:{step} \
--input-data {profiling_data} {input_shapes}"""
if shared_memory:
exec_args += " --shared-memory=cuda"
result = os.system(f"perf_client {exec_args}")
if result != 0:
print(f"Failed running performance tests. Perf client failed with exit code {result}")
sys.exit(1)
print("\n")
print(f"==== Warmup done ====")

View file

@@ -1,3 +1,16 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
networkx==2.5
numpy<1.20.0,>=1.16.0  # numpy 1.20+ requires py37+
onnx==1.8.0
@@ -9,4 +22,4 @@ tf2onnx==1.8.3
tabulate>=0.8.7
natsort>=7.0.0
# use tags instead of branch names - because a docker cache hit might prevent fetching the most recent changes on a branch
service_maker @ git+https://access-token:usVyg8b11sn9gCacsVCf@gitlab-master.nvidia.com/dl/JoC/service_maker.git@1b83b96#egg=service_maker
model_navigator @ git+https://github.com/triton-inference-server/model_navigator.git@v0.1.0#egg=model_navigator

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +16,7 @@
r"""
To run inference on the model with the framework runtime, you can use the `run_inference_on_fw.py` script.
It infers data obtained from pointed data loader locally and saves received data into
[npz files](https://gitlab-master.nvidia.com/dl/JoC/bermuda-api/-/blob/develop/bermuda_api_toolset/docs/dump_files.md).
It runs inference on data obtained from the pointed data loader locally and saves the received data into npz files.
Those files are stored in the directory pointed to by the `--output-dir` argument.
Example call:
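The example call itself lies outside this hunk; a hypothetical sketch consistent with the surrounding description (only `--output-dir` is confirmed above, the other flag names are assumptions):

```shell script
# hypothetical flags except --output-dir
python ./triton/run_inference_on_fw.py \
    --input-path ${SHARED_DIR}/model \
    --input-type ${FORMAT} \
    --dataloader triton/dataloader.py \
    --output-dir ${SHARED_DIR}/dump
```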

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +16,7 @@
r"""
To run inference on the model deployed on Triton, you can use the `run_inference_on_triton.py` script.
It sends a request with data obtained from pointed data loader and dumps received data into
[npz files](https://gitlab-master.nvidia.com/dl/JoC/bermuda-api/-/blob/develop/bermuda_api_toolset/docs/dump_files.md).
It sends requests with data obtained from the pointed data loader and dumps the received data into npz files.
Those files are stored in the directory pointed to by the `--output-dir` argument.
Currently, the client communicates with the Triton server asynchronously, using the GRPC protocol.
@@ -223,9 +222,15 @@ def _parse_args():
parser.add_argument("--dump-inputs", help="Dump inputs to output dir", action="store_true", default=False)
parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)
parser.add_argument("--output-dir", required=True, help="Path to directory where outputs will be saved")
parser.add_argument("--response-wait-time", required=False, help="Maximal time to wait for response", default=120)
parser.add_argument(
"--max-unresponded-requests", required=False, help="Maximal number of unresponded requests", default=128
"--response-wait-time", required=False, help="Maximal time to wait for response", default=120, type=float
)
parser.add_argument(
"--max-unresponded-requests",
required=False,
help="Maximal number of unresponded requests",
default=128,
type=int,
)
args, *_ = parser.parse_known_args()
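The `type=` additions matter because argparse otherwise leaves CLI-supplied values as strings while the defaults stay numeric; a minimal repro of the mismatch the change removes:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--response-wait-time", default=120)              # untyped, as before
parser.add_argument("--max-unresponded-requests", default=128, type=int)

args = parser.parse_args(["--response-wait-time", "60"])
print(type(args.response_wait_time))         # <class 'str'> -> numeric comparisons break
print(type(args.max_unresponded_requests))   # <class 'int'> (default untouched)
```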

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ where N and M are variable-size dimensions, to tell perf_analyzer to send batch-
import argparse
import csv
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional
@@ -68,14 +69,15 @@ def _parse_batch_sizes(batch_sizes: str):
def offline_performance(
model_name: str,
batch_sizes: List[int],
result_path: str,
input_shapes: Optional[List[str]] = None,
profiling_data: str = "random",
triton_instances: int = 1,
server_url: str = "localhost",
measurement_window: int = 10000,
model_name: str,
batch_sizes: List[int],
result_path: str,
input_shapes: Optional[List[str]] = None,
profiling_data: str = "random",
triton_instances: int = 1,
server_url: str = "localhost",
measurement_window: int = 10000,
shared_memory: bool = False
):
print("\n")
print(f"==== Static batching analysis start ====")
@@ -99,13 +101,15 @@ def offline_performance(
-u {server_url}:8000 \
-b {batch_size} \
-f {performance_partial_file} \
--input-data {profiling_data} {input_shapes}
"""
--input-data {profiling_data} {input_shapes}"""
if shared_memory:
exec_args += " --shared-memory=cuda"
result = os.system(f"perf_client {exec_args}")
if result != 0:
print(f"Failed running performance tests. Perf client failed with exit code {result}")
exit(1)
sys.exit(1)
update_performance_data(results, batch_size, performance_partial_file)
os.remove(performance_partial_file)
@@ -141,6 +145,8 @@ def main():
parser.add_argument(
"--measurement-window", required=False, help="Time which perf_analyzer will wait for results", default=10000
)
parser.add_argument("--shared-memory", help="Use shared memory for communication with Triton", action="store_true",
default=False)
args = parser.parse_args()
@@ -152,6 +158,7 @@ def main():
profiling_data=args.input_data,
input_shapes=args.input_shape,
measurement_window=args.measurement_window,
shared_memory=args.shared_memory
)
offline_performance(
@@ -163,6 +170,7 @@ def main():
input_shapes=args.input_shape,
result_path=args.result_path,
measurement_window=args.measurement_window,
shared_memory=args.shared_memory
)
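A sketch of how the new `--shared-memory` flag propagates into the perf_client command line (names follow the diff above; the argument list is abbreviated):

```python
def build_perf_client_cmd(model_name: str, batch_size: int, shared_memory: bool) -> str:
    """Abbreviated version of the exec_args assembly shown in the diff."""
    exec_args = f"-m {model_name} -b {batch_size} --input-data random"
    if shared_memory:
        exec_args += " --shared-memory=cuda"  # use CUDA shared memory for I/O tensors
    return f"perf_client {exec_args}"

print(build_perf_client_cmd("ResNet50", 32, shared_memory=True))
```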

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ where N and M are variable-size dimensions, to tell perf_analyzer to send batch-
import argparse
import csv
import os
import sys
from pathlib import Path
from typing import List, Optional
@@ -66,15 +67,16 @@ def _parse_batch_sizes(batch_sizes: str):
def online_performance(
model_name: str,
batch_sizes: List[int],
result_path: str,
input_shapes: Optional[List[str]] = None,
profiling_data: str = "random",
triton_instances: int = 1,
triton_gpu_engine_count: int = 1,
server_url: str = "localhost",
measurement_window: int = 10000,
model_name: str,
batch_sizes: List[int],
result_path: str,
input_shapes: Optional[List[str]] = None,
profiling_data: str = "random",
triton_instances: int = 1,
triton_gpu_engine_count: int = 1,
server_url: str = "localhost",
measurement_window: int = 10000,
shared_memory: bool = False
):
print("\n")
print(f"==== Dynamic batching analysis start ====")
@@ -85,13 +87,13 @@ def online_performance(
print(f"Running performance tests for dynamic batching")
performance_file = f"triton_performance_dynamic_partial.csv"
steps = 16
max_batch_size = max(batch_sizes)
max_concurrency = max(steps, max_batch_size * triton_instances * triton_gpu_engine_count)
step = max(1, max_concurrency // steps)
max_total_requests = 2 * max_batch_size * triton_instances * triton_gpu_engine_count
max_concurrency = min(256, max_total_requests)
batch_size = max(1, max_total_requests // 256)
step = max(1, max_concurrency // 32)
min_concurrency = step
batch_size = 1
exec_args = f"""-m {model_name} \
-x 1 \
@@ -102,13 +104,15 @@
-b {batch_size} \
-f {performance_file} \
--concurrency-range {min_concurrency}:{max_concurrency}:{step} \
--input-data {profiling_data} {input_shapes}
"""
--input-data {profiling_data} {input_shapes}"""
if shared_memory:
exec_args += " --shared-memory=cuda"
result = os.system(f"perf_client {exec_args}")
if result != 0:
print(f"Failed running performance tests. Perf client failed with exit code {result}")
exit(1)
sys.exit(1)
results = list()
update_performance_data(results=results, performance_file=performance_file)
@@ -149,6 +153,8 @@ def main():
parser.add_argument(
"--measurement-window", required=False, help="Time which perf_analyzer will wait for results", default=10000
)
parser.add_argument("--shared-memory", help="Use shared memory for communication with Triton", action="store_true",
default=False)
args = parser.parse_args()
@@ -157,9 +163,11 @@ def main():
model_name=args.model_name,
batch_sizes=_parse_batch_sizes(args.batch_sizes),
triton_instances=args.triton_instances,
triton_gpu_engine_count=args.number_of_model_instances,
profiling_data=args.input_data,
input_shapes=args.input_shape,
measurement_window=args.measurement_window,
shared_memory=args.shared_memory
)
online_performance(
@@ -172,6 +180,7 @@ def main():
input_shapes=args.input_shape,
result_path=args.result_path,
measurement_window=args.measurement_window,
shared_memory=args.shared_memory
)
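The dynamic-batching sweep uses the same request budget as the warmup but a finer step (`// 32` instead of `// 2`); a worked example under the same assumed configuration:

```python
# Hypothetical values: batch sizes up to 128, 1 Triton instance, 2 GPU engines.
max_batch_size = 128
triton_instances = 1
triton_gpu_engine_count = 2

max_total_requests = 2 * max_batch_size * triton_instances * triton_gpu_engine_count  # 512
max_concurrency = min(256, max_total_requests)  # 256
batch_size = max(1, max_total_requests // 256)  # 2
step = max(1, max_concurrency // 32)            # 8
min_concurrency = step
print(f"--concurrency-range {min_concurrency}:{max_concurrency}:{step}")  # 8:256:8
```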

View file

@@ -15,7 +15,7 @@
export PRECISION="fp16"
export FORMAT="tf-trt"
export BATCH_SIZE="1,2,4,8,16,32,64,128"
export BACKEND_ACCELERATOR="trt"
export BACKEND_ACCELERATOR="cuda"
export MAX_BATCH_SIZE="128"
export NUMBER_OF_MODEL_INSTANCES="2"
export TRITON_MAX_QUEUE_DELAY="1"