Adding DLRM/PyT

Przemek Strzelczyk 2020-04-08 18:17:57 +02:00
parent 4f0f43b9a5
commit 15807b36bf
32 changed files with 5174 additions and 1 deletion


@@ -0,0 +1,34 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
FROM ${FROM_IMAGE_NAME}
RUN apt update && \
apt install -y openjdk-8-jdk && \
curl http://archive.apache.org/dist/spark/spark-2.4.5/spark-2.4.5-bin-hadoop2.7.tgz -o /opt/spark-2.4.5-bin-hadoop2.7.tgz && \
tar zxf /opt/spark-2.4.5-bin-hadoop2.7.tgz -C /opt/ && \
rm /opt/spark-2.4.5-bin-hadoop2.7.tgz
ADD requirements.txt .
RUN pip install -r requirements.txt
RUN pip uninstall -y apex && \
git clone https://github.com/NVIDIA/apex && \
cd apex && \
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
WORKDIR /workspace/dlrm
COPY . .


@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,3 @@
DLRM for PyTorch
This repository includes software from https://github.com/facebookresearch/dlrm licensed under the MIT License


@@ -0,0 +1,516 @@
# DLRM For PyTorch
This repository provides a script and recipe to train the Deep Learning Recommendation Model (DLRM) to achieve state-of-the-art accuracy and is tested and maintained by NVIDIA.
## Table Of Contents
* [Table Of Contents](#table-of-contents)
* [Model overview](#model-overview)
* [Model architecture](#model-architecture)
* [Default configuration](#default-configuration)
* [Feature support matrix](#feature-support-matrix)
* [Features](#features)
* [Mixed precision training](#mixed-precision-training)
* [Enabling mixed precision](#enabling-mixed-precision)
* [Setup](#setup)
* [Requirements](#requirements)
* [Quick Start Guide](#quick-start-guide)
* [Advanced](#advanced)
* [Scripts and sample code](#scripts-and-sample-code)
* [Parameters](#parameters)
* [Command-line options](#command-line-options)
* [Getting the data](#getting-the-data)
* [Dataset guidelines](#dataset-guidelines)
* [Multi-dataset](#multi-dataset)
* [Preprocess with Spark](#preprocess-with-spark)
* [Training process](#training-process)
* [Inference process](#inference-process)
* [Performance](#performance)
* [Benchmarking](#benchmarking)
* [Training performance benchmark](#training-performance-benchmark)
* [Inference performance benchmark](#inference-performance-benchmark)
* [Results](#results)
* [Training accuracy results](#training-accuracy-results)
* [Training accuracy: NVIDIA DGX-1 (8x V100 32G)](#training-accuracy-nvidia-dgx-1-8x-v100-32g)
* [Training stability test](#training-stability-test)
* [Training performance results](#training-performance-results)
* [Training performance: NVIDIA DGX-1 (8x V100 32G)](#training-performance-nvidia-dgx-1-8x-v100-32g)
* [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
## Model overview
The Deep Learning Recommendation Model (DLRM) is a recommendation model designed to
make use of both categorical and numerical inputs. It was first described in
[Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
This repository provides a reimplementation of the codebase provided originally [here](https://github.com/facebookresearch/dlrm).
The scripts provided enable you to train DLRM on the [Criteo Terabyte Dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/).
This model uses a slightly different preprocessing procedure than the one found in the original implementation. You can find a detailed description of the preprocessing steps in the [Dataset guidelines](#dataset-guidelines) section.
Using DLRM you can train a high-quality general model for providing recommendations.
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta and Turing GPUs. Therefore, researchers can get results 1.77x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. It is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
### Model architecture
DLRM accepts two types of features: categorical and numerical. For each categorical
feature, an embedding table is used to provide a dense representation of each unique value. The numerical (dense) features enter the model and are transformed by a
simple neural network referred to as the "bottom MLP". This part of the network consists of a series
of linear layers with ReLU activations. The output of the bottom MLP and the embedding vectors
are then fed into the "dot interaction" operation. The output of "dot interaction" is then concatenated with the output of the bottom MLP and fed into the "top MLP", which is also a series of dense layers with activations.
The model outputs a single number that can be interpreted as the likelihood of a certain user clicking an ad.
<p align="center">
<img width="100%" src="./notebooks/DLRM_architecture.png" />
<br>
Figure 1. The architecture of DLRM.
</p>
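The listing below is a simplified, self-contained sketch of the "dot interaction" step described above: pairwise dot products between the bottom MLP output and the embedding vectors, concatenated back with the bottom MLP output. It is illustrative only; the actual implementation (including an extra padding column) lives in `dlrm/model.py`.
```
import torch

def dot_interaction(bottom_mlp_output, embedding_outputs):
    # bottom_mlp_output: [batch, d]; embedding_outputs: list of [batch, d] tensors
    vectors = torch.stack([bottom_mlp_output] + embedding_outputs, dim=1)  # [batch, n, d]
    products = torch.bmm(vectors, vectors.transpose(1, 2))                 # [batch, n, n]
    n = vectors.shape[1]
    rows, cols = torch.tril_indices(n, n, offset=-1)                       # strictly-lower triangle
    pairwise = products[:, rows, cols]                                     # [batch, n * (n - 1) / 2]
    return torch.cat([bottom_mlp_output, pairwise], dim=1)
```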
### Default configuration
The following features were implemented in this model:
- general
- static loss scaling for Tensor Cores (mixed precision) training
- preprocessing
- dataset preprocessing using Spark
### Feature support matrix
The following features are supported by this model:
| Feature | DLRM |
|---------|------|
| Automatic mixed precision (AMP) | Yes |
#### Features
Automatic Mixed Precision (AMP) - enables mixed precision training without any changes to the codebase by performing automatic casting of operations and loss scaling. In this repository it is controlled by the `--fp16`/`--nofp16` command-line flag.
### Mixed precision training
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in the Volta and Turing architecture, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
1. Porting the model to use the FP16 data type where appropriate.
2. Adding loss scaling to preserve small gradient values.
The ability to train deep learning networks with lower precision was introduced in the Pascal architecture and first supported in [CUDA 8](https://devblogs.nvidia.com/parallelforall/tag/fp16/) in the NVIDIA Deep Learning SDK.
For information about:
- How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.
- Techniques used for mixed precision training, see the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
- APEX tools for mixed precision training, see the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
#### Enabling mixed precision
Mixed precision training is enabled by default. To turn it off, pass the `--nofp16` flag to the `main.py` script.
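For reference, the snippet below is a minimal sketch of the general NVIDIA Apex AMP pattern (opt level `O2` with a static loss scale); it is not a copy of this repository's exact call sites in `dlrm/scripts/main.py`.
```
import torch
from apex import amp  # NVIDIA Apex is installed by the Dockerfile in this repository

model = torch.nn.Linear(13, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# O2 keeps FP32 master weights while running most operations in FP16
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=8192.0)

loss = model(torch.randn(64, 13).cuda()).mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # gradients are computed on the scaled loss
optimizer.step()
```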
## Setup
The following section lists the requirements for training DLRM.
### Requirements
This repository contains a Dockerfile that extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
- PyTorch 20.03-py3+ NGC container
- [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
For those unable to use the PyTorch NGC container, to set up the required environment or create your own container, see the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
## Quick Start Guide
To train your model using mixed precision with Tensor Cores or using FP32, perform the following steps using
the default parameters of DLRM on the Criteo Terabyte dataset. For the specifics concerning training and inference,
see the [Advanced](#advanced) section.
1. Clone the repository.
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/PyTorch/Recommendation/DLRM
```
2. Build a DLRM Docker container
```bash
docker build . -t nvidia_dlrm_pyt
```
3. Start an interactive session in the NGC container to run preprocessing, training, and inference.
The DLRM PyTorch container can be launched with:
```bash
mkdir -p data
docker run --runtime=nvidia -it --rm --ipc=host -v ${PWD}/data:/data nvidia_dlrm_pyt bash
```
4. Download and preprocess the dataset.
You can download the data by following the instructions at: http://labs.criteo.com/2013/12/download-terabyte-click-logs/.
When you have successfully downloaded it, put it in the `/data/dlrm/criteo/` directory in the container (`$PWD/data/dlrm/criteo` in the host system).
You can then run the preprocessing with the commands below. Note
that this will require about 4TB of disk storage.
```
cd preproc
./prepare_dataset.sh
cd -
```
5. Start training.
```
python -m dlrm.scripts.main --mode train --dataset /data/dlrm/binary_dataset/
```
6. Start validation/evaluation.
```
python -m dlrm.scripts.main --mode test --dataset /data/dlrm/binary_dataset/
```
## Advanced
The following sections provide greater details of the dataset, running training and inference, and the training results.
### Scripts and sample code
The `dlrm/scripts/main.py` script provides an entry point to most of the functionality. Using different command-line flags allows you to run training and validation, and to benchmark both training and inference on real or synthetic data.
The `dlrm/model.py` file provides the definition of the DLRM neural network.
Utilities connected to loading the data reside in the `data` directory.
### Parameters
### Command-line options
The `dlrm/scripts/main.py` script supports a number of command-line flags. You can get the descriptions of those by running `python -m dlrm.scripts.main --help`. Running this command will output:
```
USAGE: /workspace/dlrm/dlrm/scripts/main.py [flags]
flags:
/workspace/dlrm/dlrm/scripts/main.py:
--auc_threshold: Stop the training after achieving this AUC
(a number)
--base_device: Device to run the majority of the model operations
(default: 'cuda')
--batch_size: Batch size used for training
(default: '32768')
(an integer)
--benchmark_warmup_steps: Number of initial iterations to exclude from
throughput measurements
(default: '0')
(an integer)
--bottom_mlp_sizes: Linear layer sizes for the bottom MLP
(default: '512,256,128')
(a comma separated list)
--dataset: Full path to binary dataset. Must include files such as:
train_data.bin, test_data.bin
--dataset_subset: Use only a subset of the training data. If None (default)
will use all of it. Must be either None, or a float in range [0,1]
(a number)
--decay_start_step: Optimization step after which to start decaying the
learning rate, if None will start decaying right after the warmup phase is
completed
(default: '64000')
(an integer)
--decay_steps: Polynomial learning rate decay steps. If equal to 0 will not do
any decaying
(default: '80000')
(an integer)
--embedding_dim: Dimensionality of embedding space for categorical features
(default: '128')
(an integer)
--epochs: Number of epochs to train for
(default: '1')
(an integer)
--[no]fp16: If True (default) the script will use Automatic Mixed Precision
(default: 'true')
--[no]hash_indices: If True the model will compute `index := index % table
size` to ensure that the indices match table sizes
(default: 'false')
--inference_benchmark_batch_sizes: Batch sizes for inference throughput and
latency measurements
(default: '1,64,4096')
(a comma separated list)
--inference_benchmark_steps: Number of steps for measuring inference latency
and throughput
(default: '200')
(an integer)
--interaction_op: Type of interaction operation to perform. Supported choices:
'dot' or 'cat'
(default: 'dot')
--load_checkpoint_path: Path from which to load a checkpoint
--log_path: Destination for the log file with various results and statistics
(default: './log.json')
--loss_scale: Static loss scale for Mixed Precision Training
(default: '8192.0')
(a number)
--lr: Base learning rate
(default: '28.0')
(a number)
--max_steps: Stop training after doing this many optimization steps
(an integer)
--max_table_size: Maximum number of rows per embedding table, by default equal
to the number of unique values for each categorical variable
(an integer)
--mode: <train|test|inference_benchmark>: Select task to be performed
(default: 'train')
--num_numerical_features: Number of numerical features in the dataset.
Defaults to 13 for the Criteo Terabyte Dataset
(default: '13')
(an integer)
--output_dir: Path where to save the checkpoints
(default: '/tmp')
--print_freq: Number of optimizations steps between printing training status
to stdout
(default: '200')
(an integer)
--save_checkpoint_path: Path to which to save the training checkpoints
--seed: Random seed
(default: '12345')
(an integer)
--[no]self_interaction: Set to True to use self-interaction
(default: 'false')
-shuffle,--[no]shuffle_batch_order: Read batch in train dataset by random
order
(default: 'false')
--[no]synthetic_dataset: Use synthetic instead of real data for benchmarking
purposes
(default: 'false')
--synthetic_dataset_table_sizes: Embedding table sizes to use with the
synthetic dataset
(a comma separated list)
--test_after: Don't test the model unless this many epochs has been completed
(default: '0.0')
(a number)
--test_batch_size: Batch size used for testing/validation
(default: '32768')
(an integer)
--test_freq: Number of optimization steps between validations. If None will
test after each epoch
(an integer)
--top_mlp_sizes: Linear layer sizes for the top MLP
(default: '1024,1024,512,256,1')
(a comma separated list)
--warmup_factor: Learning rate warmup factor. Must be a non-negative integer
(default: '0')
(an integer)
--warmup_steps: Number of warmup optimization steps
(default: '6400')
(an integer)
```
The following example output is printed when running the model:
```
Epoch:[0/1] [200/128028] eta: 1:28:44 loss: 0.1782 step_time: 0.041657 lr: 0.8794
Epoch:[0/1] [400/128028] eta: 1:25:15 loss: 0.1403 step_time: 0.038504 lr: 1.7544
Epoch:[0/1] [600/128028] eta: 1:23:56 loss: 0.1384 step_time: 0.038422 lr: 2.6294
Epoch:[0/1] [800/128028] eta: 1:23:13 loss: 0.1370 step_time: 0.038421 lr: 3.5044
Epoch:[0/1] [1000/128028] eta: 1:22:45 loss: 0.1362 step_time: 0.038464 lr: 4.3794
Epoch:[0/1] [1200/128028] eta: 1:22:24 loss: 0.1346 step_time: 0.038455 lr: 5.2544
Epoch:[0/1] [1400/128028] eta: 1:22:07 loss: 0.1339 step_time: 0.038459 lr: 6.1294
Epoch:[0/1] [1600/128028] eta: 1:21:52 loss: 0.1320 step_time: 0.038481 lr: 7.0044
Epoch:[0/1] [1800/128028] eta: 1:21:39 loss: 0.1315 step_time: 0.038482 lr: 7.8794
Epoch:[0/1] [2000/128028] eta: 1:21:27 loss: 0.1304 step_time: 0.038466 lr: 8.7544
Epoch:[0/1] [2200/128028] eta: 1:21:15 loss: 0.1305 step_time: 0.038430 lr: 9.6294
```
### Getting the data
This example uses the [Criteo Terabyte Dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/).
The first 23 days are used as the training set. The last day is split in half. The first part is used as a validation set and the second one as a hold-out test set.
#### Dataset guidelines
The preprocessing steps applied to the raw data include:
- Replacing the missing values with `0`
- Replacing categorical values that occur fewer than 15 times with a special value
- Converting the hash values to consecutive integers
- Adding 2 to all the numerical features so that all of them are greater than or equal to 1
- Taking the natural logarithm of all numerical features (a rough sketch of these numerical steps follows the list)
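As a rough illustration of the numerical steps only (the real preprocessing runs in Spark, see [Preprocess with Spark](#preprocess-with-spark); missing values are assumed to arrive as NaN here):
```
import numpy as np

def transform_numerical(raw_features):
    x = np.nan_to_num(np.asarray(raw_features, dtype=np.float64), nan=0.0)  # missing values -> 0
    x = x + 2.0                                                             # shift so every value is >= 1
    return np.log(x)                                                        # natural logarithm
```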
#### Multi-dataset
Our preprocessing scripts are designed for the Criteo Terabyte Dataset and should work with any other dataset with the same format. The data should be split into text files. Each line of those text files should contain a single training example. An example should consist of multiple fields separated by tab characters (a minimal parsing sketch follows the list below):
- The first field is the label `1` for a positive example and `0` for negative.
- The next `N` tokens should contain the numerical features separated by tabs.
- The next `M` tokens should contain the hashed categorical features separated by tabs.
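A minimal, hypothetical parser for one such line (the function name, defaults, and hexadecimal decoding of the hashed categorical features are assumptions, not part of the repository):
```
def parse_line(line, num_numerical=13):
    fields = line.rstrip("\n").split("\t")
    label = int(fields[0])
    # empty fields denote missing values
    numerical = [float(f) if f else 0.0 for f in fields[1:1 + num_numerical]]
    categorical = [int(f, 16) if f else 0 for f in fields[1 + num_numerical:]]
    return label, numerical, categorical
```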
#### Preprocess with Spark
The script `spark_data_utils.py` is a PySpark application used to preprocess the Criteo Terabyte Dataset. The Docker image ships with Spark 2.4.5, which is run as a standalone cluster. The script `run-spark.sh` starts Spark and then runs several PySpark jobs with `spark_data_utils.py` to:
- generate the dictionary,
- transform the train dataset,
- transform the test dataset,
- transform the validation dataset.
Change the variables in the `run-spark.sh` script according to your environment.
Configure the paths.
```
export SPARK_LOCAL_DIRS=/data/spark-tmp
export INPUT_PATH=/data/criteo
export OUTPUT_PATH=/data/output
```
Note that the Spark job requires about 3TB of disk space for shuffle data.
`SPARK_LOCAL_DIRS` is the path Spark uses to write shuffle data.
`INPUT_PATH` is the path to the Criteo Terabyte Dataset, including the uncompressed files such as day_0, day_1, and so on.
`OUTPUT_PATH` is where the script writes the output data. It will contain the following subdirectories: `models`, `train`, `test`, and `validation`.
- `models` is the dictionary folder.
- `train` is the train dataset transformed from day_0 to day_22.
- `test` is the test dataset transformed from the first half of day_23.
- `validation` is the validation dataset transformed from the second half of day_23.
Configure the resources which Spark will use.
```
export TOTAL_CORES=80
export TOTAL_MEMORY=800
```
`TOTAL_CORES` is the total number of CPU cores you want Spark to use.
`TOTAL_MEMORY` is the total amount of memory Spark will use.
Configure frequency limit.
```
USE_FREQUENCY_LIMIT=15
```
The frequency limit is used to filter out the categorical values that appear fewer than n times in the whole dataset and map them to 0. Set this variable to 1 to enable the filtering. The default frequency limit used in the script is 15; you can also change the limit by editing the line `OPTS="--frequency_limit 8"`.
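Conceptually, the frequency limit corresponds to the following toy transformation (a sketch only; the actual job is a distributed Spark aggregation in `spark_data_utils.py`):
```
from collections import Counter

def apply_frequency_limit(values, limit=15):
    # map categorical values seen fewer than `limit` times across the dataset to 0
    counts = Counter(values)
    return [value if counts[value] >= limit else 0 for value in values]
```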
After the above configuration, you can run `run-spark.sh` directly if you have already downloaded the dataset, or run `prepare_dataset.sh`, which verifies the downloaded dataset and then runs the preprocessing job.
### Training process
The main training script resides in `dlrm/scripts/main.py`. Once the training is completed, it stores the checkpoint
in the path specified by `--save_checkpoint_path` and a training log in `--log_path`. The quality of the predictions
generated by the model is measured by the [ROC AUC metric](https://scikit-learn.org/stable/modules/model_evaluation.html#roc-metrics).
The speed of training and inference is measured by throughput, i.e., the number
of samples processed per second. We use mixed precision training with static loss scaling for the bottom and top MLPs, while the embedding tables are stored in FP32 format.
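For reference, the ROC AUC metric can be reproduced offline with scikit-learn from saved scores and labels (the arrays below are illustrative only):
```
from sklearn.metrics import roc_auc_score

labels = [0, 1, 1, 0, 1]             # ground-truth clicks
scores = [0.1, 0.8, 0.65, 0.3, 0.9]  # model scores
print(roc_auc_score(labels, scores))  # 1.0 because every positive is ranked above every negative
```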
### Inference process
This section describes inference with PyTorch in Python. If you're interested in inference using the Triton Inference Server, refer to the `triton/README.md` file.
Two modes for inference are currently supported by the `dlrm/scripts/main.py` script:
1. Inference benchmark: this mode measures and prints throughput and latency numbers for multiple batch sizes (see the sketch after this list). You can activate it by passing `--mode inference_benchmark` and setting the batch sizes to be tested with the `--inference_benchmark_batch_sizes` command-line argument. It will use the default test dataset unless the `--synthetic_dataset` flag is passed.
2. Test-only: this mode runs a full validation on a checkpoint to measure ROC AUC. You can enable it by passing the `--mode test` flag.
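The snippet below is only a rough sketch of how per-batch latency could be measured; the repository's actual benchmark is implemented in `dlrm/scripts/main.py`, and `measure_latency` is a hypothetical helper:
```
import time
import torch

def measure_latency(model, batches):
    model.eval()
    latencies = []
    with torch.no_grad():
        for numerical_features, categorical_features, _ in batches:
            start = time.time()
            model(numerical_features, categorical_features)
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # wait for asynchronous GPU work before stopping the clock
            latencies.append(time.time() - start)
    return latencies
```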
## Performance
### Benchmarking
The following section shows how to run benchmarks measuring the model performance in training and inference modes.
#### Training performance benchmark
To benchmark the training performance on a specific batch size, run:
```
python -m dlrm.scripts.main --mode train --max_steps 500 --benchmark_warmup_steps 250 --dataset /data
```
You can also pass the `--synthetic_dataset` flag if you haven't yet downloaded the dataset.
#### Inference performance benchmark
To benchmark the inference performance on a specific batch size, run:
```
python -m dlrm.scripts.main --mode inference_benchmark --dataset /data
```
You can also pass the `--synthetic_dataset` flag if you haven't yet downloaded the dataset.
### Results
The following sections provide details on how we achieved our performance and accuracy in training and inference.
#### Training accuracy results
##### Training accuracy: NVIDIA DGX-1 (8x V100 32G)
Our results were obtained by running the `dlrm/scripts/main.py` training script for one epoch, as described in the Quick Start Guide, in the DLRM Docker container on a single Tesla V100 32G GPU.
| GPUs | Batch size / GPU | Accuracy (AUC) - FP32 | Accuracy (AUC) - mixed precision | Time to train - FP32 [hours] | Time to train - mixed precision [hours] | Time to train speedup (FP32 to mixed precision) |
|----|----|----|----|---|---|---|
| 1 | 32k | 0.80362 | 0.80362 | 2.46 | 1.44 | 1.71 |
##### Training stability test
The table below shows the complete convergence data for 16 different random seeds.
| Random seed | Mixed precision AUC | Single precision AUC |
|-------:|---------:|---------:|
| 8 | 0.803696 | 0.803669 |
| 9 | 0.803617 | 0.803574 |
| 10 | 0.803672 | 0.80367 |
| 11 | 0.803699 | 0.803683 |
| 12 | 0.803659 | 0.803724 |
| 13 | 0.803578 | 0.803565 |
| 14 | 0.803609 | 0.803613 |
| 15 | 0.803585 | 0.803615 |
| 16 | 0.803553 | 0.803583 |
| 17 | 0.803644 | 0.803688 |
| 18 | 0.803656 | 0.803609 |
| 19 | 0.803589 | 0.803635 |
| 20 | 0.803567 | 0.803611 |
| 21 | 0.803548 | 0.803487 |
| 22 | 0.803532 | 0.803591 |
| 23 | 0.803625 | 0.803601 |
| **mean** | **0.803614** | **0.803620** |
#### Training performance results
##### Training performance: NVIDIA DGX-1 (8x V100 32G)
Our results were obtained by running:
```
python -m dlrm.scripts.main --mode train --max_steps 200 --benchmark_warmup_steps 50 --fp16 --dataset /data
```
in the DLRM Docker container on an NVIDIA DGX-1 with 8x V100 32G GPUs. Performance numbers (in items per second) were averaged over 150 training steps.
| GPUs | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) |
|----|---|---|---|---|
| 1 | 32k | 494k | 875k | 1.773 |
We used throughput in items processed per second as the performance metric.
## Release notes
### Changelog
April 2020
- Initial release
### Known issues
There are no known issues with this model.


@@ -0,0 +1,98 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import time
import numpy as np
import argparse
import torch
from torch.utils.data import Dataset
class CriteoBinDataset(Dataset):
"""Simple dataloader for a recommender system. Designed to work with a single binary file."""
def __init__(self, data_file, batch_size=1, subset=None,
numerical_features=13, categorical_features=26,
data_type='int32', online_shuffle=True):
self.data_type = np.__dict__[data_type]
bytes_per_feature = self.data_type().nbytes
self.tad_fea = 1 + numerical_features
self.tot_fea = 1 + numerical_features + categorical_features
self.batch_size = batch_size
self.bytes_per_entry = (bytes_per_feature * self.tot_fea * batch_size)
self.num_entries = math.ceil(os.path.getsize(data_file) / self.bytes_per_entry)
if subset is not None:
if subset <= 0 or subset > 1:
raise ValueError('Subset parameter must be in (0, 1] range')
# cast to int so that __len__ returns an integer number of batches
self.num_entries = int(self.num_entries * subset)
print('data file:', data_file, 'number of batches:', self.num_entries)
self.file = open(data_file, 'rb')
self.online_shuffle=online_shuffle
def __len__(self):
return self.num_entries
def __getitem__(self, idx):
if idx == 0:
self.file.seek(0, 0)
if self.online_shuffle:
self.file.seek(idx * self.bytes_per_entry, 0)
raw_data = self.file.read(self.bytes_per_entry)
array = np.frombuffer(raw_data, dtype=self.data_type).reshape(-1, self.tot_fea)
# numerical features are encoded as float32
numerical_features = array[:, 1:self.tad_fea].view(dtype=np.float32)
numerical_features = torch.from_numpy(numerical_features)
categorical_features = torch.from_numpy(array[:, self.tad_fea:])
labels = torch.from_numpy(array[:, 0])
return numerical_features, categorical_features, labels
def __del__(self):
self.file.close()
if __name__ == '__main__':
print('Dataloader benchmark')
parser = argparse.ArgumentParser()
parser.add_argument('--file', type=str)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--steps', type=int, default=1000)
args = parser.parse_args()
dataset = CriteoBinDataset(data_file=args.file, batch_size=args.batch_size)
begin = time.time()
for i in range(args.steps):
_ = dataset[i]
end = time.time()
step_time = (end - begin) / args.steps
throughput = args.batch_size / step_time
print(f'Mean step time: {step_time:.6f} [s]')
print(f'Mean throughput: {throughput:,.0f} [samples / s]')
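# Illustrative note (not part of the original script): the training code wraps this dataset in a
# PyTorch DataLoader with batch_size=None, because each item returned above is already a full batch, e.g.:
#   dataset = CriteoBinDataset('/data/dlrm/binary_dataset/train_data.bin', batch_size=32768)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=0, pin_memory=False)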


@@ -0,0 +1,42 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from torch.utils.data import Dataset
class SyntheticDataset(Dataset):
"""Synthetic dataset version of criteo dataset."""
def __init__(self, num_entries, device='cuda', batch_size=1, dense_features=13,
categorical_feature_sizes=None):
# dataset. single target, 13 dense features, 26 sparse features
self.sparse_features = len(categorical_feature_sizes)
self.dense_features = dense_features
self.tot_fea = 1 + dense_features + self.sparse_features
self.batch_size = batch_size
self.batches_per_epoch = math.ceil(num_entries / batch_size)
self.categorical_feature_sizes = categorical_feature_sizes
self.device = device
self.tensor = torch.randint(low=0, high=2, size=(self.batch_size, self.tot_fea), device=self.device)
self.tensor = self.tensor.float()
def __len__(self):
return self.batches_per_epoch
def __getitem__(self, idx):
# split the single random batch into numerical features, categorical features and the label
return self.tensor[:, 1:1 + self.dense_features], self.tensor[:, 1 + self.dense_features:], self.tensor[:, 0]
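# Illustrative usage (not part of the original script): every index returns the same random batch,
# shaped like a Criteo batch, e.g.:
#   dataset = SyntheticDataset(num_entries=10**6, batch_size=4, categorical_feature_sizes=26 * [10**5])
#   numerical, categorical, labels = dataset[0]  # shapes: [4, 13], [4, 26], [4]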


@@ -0,0 +1,224 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
from absl import logging
import torch
from torch import nn
from typing import List
class Dlrm(nn.Module):
"""Reimplement Facebook's DLRM model
Original implementation is from https://github.com/facebookresearch/dlrm.
"""
def __init__(self, num_numerical_features, categorical_feature_sizes, bottom_mlp_sizes, top_mlp_sizes,
embedding_dim=32, interaction_op="dot", self_interaction=False, hash_indices=False,
base_device="cuda", sigmoid=False):
# Running everything on gpu by default
self._base_device = base_device
self._embedding_device_map = [base_device for _ in range(len(categorical_feature_sizes))]
super(Dlrm, self).__init__()
if embedding_dim != bottom_mlp_sizes[-1]:
raise TypeError("The last bottom MLP layer must have same size as embedding.")
self._embedding_dim = embedding_dim
self._interaction_op = interaction_op
self._self_interaction = self_interaction
self._hash_indices = hash_indices
self._categorical_feature_sizes = copy.copy(categorical_feature_sizes)
# Interactions are among the outputs of all the embedding tables and the bottom MLP: a total of
# (num_embedding_tables + 1) vectors of size embedding_dim. The ``dot`` interaction computes the dot product
# between every pair of vectors, while the ``cat`` interaction concatenates all the vectors together.
# The flattened interaction output has num_interactions features per sample.
self._num_interaction_inputs = len(categorical_feature_sizes) + 1
if interaction_op == "dot":
if self_interaction:
raise NotImplementedError
num_interactions = (self._num_interaction_inputs * (self._num_interaction_inputs - 1)) // 2 + embedding_dim
elif interaction_op == "cat":
num_interactions = self._num_interaction_inputs * embedding_dim
else:
raise TypeError(F"Unknown interaction {interaction_op}.")
self.embeddings = nn.ModuleList()
self._create_embeddings(self.embeddings, embedding_dim, categorical_feature_sizes)
# Create bottom MLP
bottom_mlp_layers = []
input_dims = num_numerical_features
for output_dims in bottom_mlp_sizes:
bottom_mlp_layers.append(
nn.Linear(input_dims, output_dims))
bottom_mlp_layers.append(nn.ReLU(inplace=True))
input_dims = output_dims
self.bottom_mlp = nn.Sequential(*bottom_mlp_layers)
# Create Top MLP
top_mlp_layers = []
input_dims = num_interactions
if self._interaction_op == 'dot':
input_dims += 1 # pad 1 to be multiple of 8
for output_dims in top_mlp_sizes[:-1]:
top_mlp_layers.append(nn.Linear(input_dims, output_dims))
top_mlp_layers.append(nn.ReLU(inplace=True))
input_dims = output_dims
# last Linear layer uses sigmoid
top_mlp_layers.append(nn.Linear(input_dims, top_mlp_sizes[-1]))
if sigmoid:
top_mlp_layers.append(nn.Sigmoid())
self.top_mlp = nn.Sequential(*top_mlp_layers)
self._initialize_mlp_weights()
self._interaction_padding = torch.zeros(1, 1, dtype=torch.float32)
self.tril_indices = torch.tensor([[i for i in range(len(self.embeddings) + 1)
for j in range(i + int(self_interaction))],
[j for i in range(len(self.embeddings) + 1)
for j in range(i + int(self_interaction))]])
def _interaction(self,
bottom_mlp_output: torch.Tensor,
embedding_outputs: List[torch.Tensor],
batch_size: int) -> torch.Tensor:
"""Interaction
"dot" interaction is a bit tricky to implement and test. Break it out from forward so that it can be tested
independently.
Args:
bottom_mlp_output (Tensor):
embedding_outputs (list): Sequence of tensors
batch_size (int):
"""
if self._interaction_padding is None:
self._interaction_padding = torch.zeros(
batch_size, 1, dtype=bottom_mlp_output.dtype, device=bottom_mlp_output.device)
concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
if self._interaction_op == "dot" and not self._self_interaction:
concat = concat.view((-1, self._num_interaction_inputs, self._embedding_dim))
interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
interaction_flat = interaction[:, self.tril_indices[0], self.tril_indices[1]]
# concatenate dense features and interactions
interaction_padding = self._interaction_padding.expand(batch_size, 1).to(dtype=bottom_mlp_output.dtype)
interaction_output = torch.cat(
(bottom_mlp_output, interaction_flat, interaction_padding), dim=1)
elif self._interaction_op == "cat":
interaction_output = concat
else:
raise NotImplementedError
return interaction_output
def _initialize_mlp_weights(self):
"""Initializing weights same as original DLRM"""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight.data, 0., math.sqrt(2. / (module.in_features + module.out_features)))
nn.init.normal_(module.bias.data, 0., math.sqrt(1. / module.out_features))
# Explicitly zero the weights corresponding to the zero-padded interaction output. They will
# stay 0 throughout the entire training, so the padding does not add model capacity; an assert
# could be added at the end of training to verify this.
nn.init.zeros_(self.top_mlp[0].weight[:, -1].data)
@property
def num_categorical_features(self):
return len(self._categorical_feature_sizes)
def extra_repr(self):
s = (F"interaction_op={self._interaction_op}, self_interaction={self._self_interaction}, "
F"hash_indices={self._hash_indices}")
return s
# pylint:enable=missing-docstring
@classmethod
def from_dict(cls, obj_dict, **kwargs):
"""Create from json str"""
return cls(**obj_dict, **kwargs)
def _create_embeddings(self, embeddings, embedding_dim, categorical_feature_sizes):
# Each embedding table has size [num_features, embedding_dim]
for i, num_features in enumerate(categorical_feature_sizes):
# Allocating directly on the GPU is much faster than allocating on the CPU and then copying over
embedding_weight = torch.empty((num_features, embedding_dim), device=self._embedding_device_map[i])
embedding = nn.Embedding.from_pretrained(embedding_weight, freeze=False, sparse=True)
# Initializing embedding same as original DLRM
nn.init.uniform_(
embedding.weight.data,
-math.sqrt(1. / embedding.num_embeddings),
math.sqrt(1. / embedding.num_embeddings))
embeddings.append(embedding)
def set_devices(self, base_device):
"""Set devices to run the model
Args:
base_device (string);
"""
self._base_device = base_device
self.bottom_mlp.to(base_device)
self.top_mlp.to(base_device)
self._interaction_padding = self._interaction_padding.to(base_device)
self._embedding_device_map = [base_device for _ in range(self.num_categorical_features)]
for embedding_id, device in enumerate(self._embedding_device_map):
logging.info("Place embedding %d on device %s", embedding_id, device)
self.embeddings[embedding_id].to(device)
def forward(self, numerical_input, categorical_inputs):
"""
Args:
numerical_input (Tensor): with shape [batch_size, num_numerical_features]
categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]
"""
batch_size = numerical_input.size()[0]
# Put indices on the same device as corresponding embedding
device_indices = []
for embedding_id, _ in enumerate(self.embeddings):
device_indices.append(categorical_inputs[:, embedding_id].to(self._embedding_device_map[embedding_id]))
bottom_mlp_output = self.bottom_mlp(numerical_input)
# embedding_outputs will be a list of (26 in the case of Criteo) fetched embeddings with shape
# [batch_size, embedding_size]
embedding_outputs = []
for embedding_id, embedding in enumerate(self.embeddings):
if self._hash_indices:
device_indices[embedding_id] = device_indices[embedding_id] % embedding.num_embeddings
embedding_outputs.append(embedding(device_indices[embedding_id]).to(self._base_device))
interaction_output = self._interaction(bottom_mlp_output, embedding_outputs, batch_size)
top_mlp_output = self.top_mlp(interaction_output)
return top_mlp_output
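# Minimal smoke test (illustrative, not part of the original file): build a tiny DLRM on the CPU
# with made-up table sizes and run a single forward pass on random inputs.
if __name__ == '__main__':
    model = Dlrm(num_numerical_features=13,
                 categorical_feature_sizes=[100, 200, 300],
                 bottom_mlp_sizes=[64, 32, 16],
                 top_mlp_sizes=[64, 32, 1],
                 embedding_dim=16,
                 base_device='cpu')
    numerical_input = torch.rand(8, 13)
    categorical_inputs = torch.randint(low=0, high=100, size=(8, 3))
    print(model(numerical_input, categorical_inputs).shape)  # torch.Size([8, 1])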


@@ -0,0 +1,510 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import numpy as np
import json
from pprint import pprint
from time import time
from sklearn.metrics import roc_auc_score
from absl import app
from absl import flags
import dllogger
import torch
from apex import amp
from dlrm.data import data_loader
from dlrm.data.synthetic_dataset import SyntheticDataset
from dlrm.model import Dlrm
import dlrm.scripts.utils as utils
FLAGS = flags.FLAGS
# Basic run settings
flags.DEFINE_enum("mode", default='train', enum_values=['train', 'test', 'inference_benchmark'],
help="Select task to be performed")
flags.DEFINE_integer("seed", 12345, "Random seed")
# Training schedule flags
flags.DEFINE_integer("batch_size", 32768, "Batch size used for training")
flags.DEFINE_integer("test_batch_size", 32768, "Batch size used for testing/validation")
flags.DEFINE_float("lr", 28, "Base learning rate")
flags.DEFINE_integer("epochs", 1, "Number of epochs to train for")
flags.DEFINE_integer("max_steps", None, "Stop training after doing this many optimization steps")
flags.DEFINE_integer("warmup_factor", 0, "Learning rate warmup factor. Must be a non-negative integer")
flags.DEFINE_integer("warmup_steps", 6400, "Number of warmup optimization steps")
flags.DEFINE_integer("decay_steps", 80000, "Polynomial learning rate decay steps. If equal to 0 will not do any decaying")
flags.DEFINE_integer("decay_start_step", 64000,
"Optimization step after which to start decaying the learning rate, if None will start decaying right after the warmup phase is completed")
# Model configuration
flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of embedding space for categorical features")
flags.DEFINE_list("top_mlp_sizes", [1024, 1024, 512, 256, 1], "Linear layer sizes for the top MLP")
flags.DEFINE_list("bottom_mlp_sizes", [512, 256, 128], "Linear layer sizes for the bottom MLP")
flags.DEFINE_string("interaction_op", "dot",
"Type of interaction operation to perform. Supported choices: 'dot' or 'cat'")
flags.DEFINE_boolean("self_interaction", False, "Set to True to use self-interaction")
flags.DEFINE_string(
"dataset", None,
"Full path to binary dataset. Must include files such as: train_data.bin, test_data.bin")
flags.DEFINE_boolean("synthetic_dataset", False, "Use synthetic instead of real data for benchmarking purposes")
flags.DEFINE_list("synthetic_dataset_table_sizes", default=','.join(26 * [str(10**5)]),
help="Embedding table sizes to use with the synthetic dataset")
flags.DEFINE_boolean("shuffle_batch_order", False, "Read batch in train dataset by random order", short_name="shuffle")
flags.DEFINE_integer("num_numerical_features", 13,
"Number of numerical features in the dataset. Defaults to 13 for the Criteo Terabyte Dataset")
flags.DEFINE_integer("max_table_size", None,
"Maximum number of rows per embedding table, by default equal to the number of unique values for each categorical variable")
flags.DEFINE_boolean("hash_indices", False,
"If True the model will compute `index := index % table size` to ensure that the indices match table sizes")
flags.DEFINE_float("dataset_subset", None,
"Use only a subset of the training data. If None (default) will use all of it. Must be either None, or a float in range [0,1]")
# Checkpointing
flags.DEFINE_string("load_checkpoint_path", None, "Path from which to load a checkpoint")
flags.DEFINE_string("save_checkpoint_path", None, "Path to which to save the training checkpoints")
# Saving and logging flags
flags.DEFINE_string("output_dir", "/tmp", "Path where to save the checkpoints")
flags.DEFINE_string("log_path", "./log.json", "Destination for the log file with various results and statistics")
flags.DEFINE_integer("test_freq", None, "Number of optimization steps between validations. If None will test after each epoch")
flags.DEFINE_float("test_after", 0, "Don't test the model unless this many epochs has been completed")
flags.DEFINE_integer("print_freq", 200, "Number of optimizations steps between printing training status to stdout")
flags.DEFINE_integer("benchmark_warmup_steps", 0, "Number of initial iterations to exclude from throughput measurements")
# Machine setting flags
flags.DEFINE_string("base_device", "cuda", "Device to run the majority of the model operations")
flags.DEFINE_boolean("fp16", True, "If True (default) the script will use Automatic Mixed Precision")
flags.DEFINE_float("loss_scale", 8192, "Static loss scale for Mixed Precision Training")
# inference benchmark
flags.DEFINE_list("inference_benchmark_batch_sizes", default=[1, 64, 4096],
help="Batch sizes for inference throughput and latency measurements")
flags.DEFINE_integer("inference_benchmark_steps", 200,
"Number of steps for measuring inference latency and throughput")
flags.DEFINE_float("auc_threshold", None, "Stop the training after achieving this AUC")
def validate_flags():
if FLAGS.max_table_size is not None and not FLAGS.hash_indices:
raise ValueError('Hash indices must be True when setting a max_table_size')
def create_synthetic_datasets(train_batch_size, test_batch_size):
categorical_sizes = get_categorical_feature_sizes()
dataset_train = SyntheticDataset(num_entries=4 * 10**9,
batch_size=train_batch_size,
dense_features=FLAGS.num_numerical_features,
categorical_feature_sizes=categorical_sizes)
dataset_test = SyntheticDataset(num_entries=100 * 10**6,
batch_size=test_batch_size,
dense_features=FLAGS.num_numerical_features,
categorical_feature_sizes=categorical_sizes)
return dataset_train, dataset_test
def create_real_datasets(train_batch_size, test_batch_size, online_shuffle=True):
train_dataset = os.path.join(FLAGS.dataset, "train_data.bin")
test_dataset = os.path.join(FLAGS.dataset, "test_data.bin")
categorical_sizes = get_categorical_feature_sizes()
dataset_train = data_loader.CriteoBinDataset(
data_file=train_dataset,
batch_size=train_batch_size, subset=FLAGS.dataset_subset,
numerical_features=FLAGS.num_numerical_features,
categorical_features=len(categorical_sizes),
online_shuffle=online_shuffle
)
dataset_test = data_loader.CriteoBinDataset(
data_file=test_dataset, batch_size=test_batch_size,
numerical_features=FLAGS.num_numerical_features,
categorical_features=len(categorical_sizes),
online_shuffle = False
)
return dataset_train, dataset_test
def get_dataloaders(train_batch_size, test_batch_size):
print("Creating data loaders")
if FLAGS.synthetic_dataset:
dataset_train, dataset_test = create_synthetic_datasets(train_batch_size, test_batch_size)
else:
dataset_train, dataset_test = create_real_datasets(train_batch_size,
test_batch_size,
online_shuffle=FLAGS.shuffle_batch_order)
if FLAGS.shuffle_batch_order and not FLAGS.synthetic_dataset:
train_sampler = torch.utils.data.RandomSampler(dataset_train)
else:
train_sampler = None
data_loader_train = torch.utils.data.DataLoader(
dataset_train, batch_size=None, num_workers=0, pin_memory=False, sampler=train_sampler)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=None, num_workers=0, pin_memory=False)
return data_loader_train, data_loader_test
def get_categorical_feature_sizes():
if FLAGS.synthetic_dataset:
feature_sizes = [int(s) for s in FLAGS.synthetic_dataset_table_sizes]
return feature_sizes
categorical_sizes_file = os.path.join(FLAGS.dataset, "model_size.json")
with open(categorical_sizes_file) as f:
categorical_sizes = json.load(f).values()
categorical_sizes = list(categorical_sizes)
# need to add 1 because the JSON file contains the max value not the count
categorical_sizes = [s + 1 for s in categorical_sizes]
if FLAGS.max_table_size is None:
return categorical_sizes
clipped_sizes = [min(s, FLAGS.max_table_size) for s in categorical_sizes]
return clipped_sizes
def create_model():
print("Creating model")
model_config = {
'top_mlp_sizes': FLAGS.top_mlp_sizes,
'bottom_mlp_sizes': FLAGS.bottom_mlp_sizes,
'embedding_dim': FLAGS.embedding_dim,
'interaction_op': FLAGS.interaction_op,
'self_interaction': FLAGS.self_interaction,
'categorical_feature_sizes': get_categorical_feature_sizes(),
'num_numerical_features': FLAGS.num_numerical_features,
'hash_indices': FLAGS.hash_indices,
'base_device': FLAGS.base_device,
}
model = Dlrm.from_dict(model_config)
print(model)
if FLAGS.load_checkpoint_path is not None:
model.load_state_dict(torch.load(FLAGS.load_checkpoint_path, map_location="cpu"))
model.to(FLAGS.base_device)
return model
def main(argv):
validate_flags()
torch.manual_seed(FLAGS.seed)
utils.init_logging(log_path=FLAGS.log_path)
dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')
data_loader_train, data_loader_test = get_dataloaders(train_batch_size=FLAGS.batch_size,
test_batch_size=FLAGS.test_batch_size)
scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr
model = create_model()
optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)
if FLAGS.fp16 and FLAGS.mode == 'train':
(model.top_mlp, model.bottom_mlp), optimizer = amp.initialize([model.top_mlp, model.bottom_mlp],
optimizer, opt_level="O2",
loss_scale=1)
elif FLAGS.fp16:
model = model.half()
loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")
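    # JIT-trace the loss function; the traced graph allows the TorchScript fuser to combine
    # its pointwise operations.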
loss_fn = torch.jit.trace(loss_fn.forward, (torch.rand(FLAGS.batch_size, 1).cuda(),
torch.rand(FLAGS.batch_size, 1).cuda()))
if FLAGS.mode == 'test':
loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
avg_test_throughput = FLAGS.batch_size / test_step_time
results = {'auc': auc,
'avg_inference_latency': test_step_time,
'average_test_throughput': avg_test_throughput}
dllogger.log(data=results, step=tuple())
print(F"Finished testing. Test Loss {loss:.4f}, auc {auc:.4f}")
return
if FLAGS.mode == 'inference_benchmark':
results = {}
if FLAGS.fp16:
# can use pure FP16 for inference
model = model.half()
for batch_size in FLAGS.inference_benchmark_batch_sizes:
batch_size = int(batch_size)
_, benchmark_data_loader = get_dataloaders(train_batch_size=batch_size,
test_batch_size=batch_size)
latencies = inference_benchmark(model=model, data_loader=benchmark_data_loader,
num_batches=FLAGS.inference_benchmark_steps)
print("All inference latencies: {}".format(latencies))
mean_latency = np.mean(latencies)
mean_inference_throughput = batch_size / mean_latency
subresult = {F'mean_inference_latency_batch_{batch_size}': mean_latency,
F'mean_inference_throughput_batch_{batch_size}': mean_inference_throughput}
results.update(subresult)
dllogger.log(data=results, step=tuple())
print(F"Finished inference benchmark.")
return
if FLAGS.mode == 'train':
train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled_lr)
def maybe_save_checkpoint(model, path):
if path is None:
return
begin = time()
torch.save(model.state_dict(), path)
end = time()
print(f'Checkpoint saving took {end-begin:,.2f} [s]')
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled_lr):
"""Train and evaluate the model
Args:
model (dlrm):
loss_fn (torch.nn.Module): Loss function
optimizer (torch.nn.optim):
data_loader_train (torch.utils.data.DataLoader):
data_loader_test (torch.utils.data.DataLoader):
"""
model.train()
base_device = FLAGS.base_device
print_freq = FLAGS.print_freq
steps_per_epoch = len(data_loader_train)
test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.6f}'))
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
timer = utils.StepTimer()
best_auc = 0
best_epoch = 0
start_time = time()
for epoch in range(FLAGS.epochs):
batch_iter = iter(data_loader_train)
for step in range(len(data_loader_train)):
timer.click()
global_step = steps_per_epoch * epoch + step
numerical_features, categorical_features, click = next(batch_iter)
categorical_features = categorical_features.to(base_device).to(torch.long)
numerical_features = numerical_features.to(base_device)
click = click.to(base_device).to(torch.float32)
utils.lr_step(optimizer, num_warmup_iter=FLAGS.warmup_steps, current_step=global_step + 1,
base_lr=scaled_lr, warmup_factor=FLAGS.warmup_factor,
decay_steps=FLAGS.decay_steps, decay_start_step=FLAGS.decay_start_step)
if FLAGS.max_steps and global_step > FLAGS.max_steps:
print(F"Reached max global steps of {FLAGS.max_steps}. Stopping.")
break
output = model(numerical_features, categorical_features).squeeze().float()
loss = loss_fn(output, click.squeeze())
optimizer.zero_grad()
if FLAGS.fp16:
loss *= FLAGS.loss_scale
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
loss_value = loss.item()
if timer.measured is None:
# first iteration, no step time etc. to print
continue
if global_step < FLAGS.benchmark_warmup_steps:
metric_logger.update(
loss=loss_value, lr=optimizer.param_groups[0]["lr"])
else:
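                # Undo the static loss scaling when reporting, so the logged loss and LR are
                # comparable between FP16 and FP32 runs.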
unscale_factor = FLAGS.loss_scale if FLAGS.fp16 else 1
metric_logger.update(
loss=loss_value / unscale_factor, step_time=timer.measured,
lr=optimizer.param_groups[0]["lr"] * unscale_factor
)
if step % print_freq == 0 and step > 0:
if global_step < FLAGS.benchmark_warmup_steps:
print(F'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]')
continue
eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
metric_logger.print(
header=F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}")
if (global_step + 1) % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
print(F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}")
if auc > best_auc:
best_auc = auc
best_epoch = epoch + ((step + 1) / steps_per_epoch)
maybe_save_checkpoint(model, FLAGS.save_checkpoint_path)
if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
stop_time = time()
run_time_s = int(stop_time - start_time)
print(F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
return
avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg
results = {'best_auc' : best_auc,
'best_epoch' : best_epoch,
'average_train_throughput' : avg_throughput}
if 'test_step_time' in locals():
avg_test_throughput = FLAGS.test_batch_size / test_step_time
results['average_test_throughput'] = avg_test_throughput
dllogger.log(data=results, step=tuple())
def evaluate(model, loss_fn, data_loader):
"""Test dlrm model
Args:
model (dlrm):
loss_fn (torch.nn.Module): Loss function
data_loader (torch.utils.data.DataLoader):
"""
model.eval()
base_device = FLAGS.base_device
print_freq = FLAGS.print_freq
steps_per_epoch = len(data_loader)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
with torch.no_grad():
y_true = []
y_score = []
timer = utils.StepTimer()
batch_iter = iter(data_loader)
timer.click()
for step in range(len(data_loader)):
numerical_features, categorical_features, click = next(batch_iter)
categorical_features = categorical_features.to(base_device).to(torch.long)
numerical_features = numerical_features.to(base_device)
click = click.to(torch.float32).to(base_device)
if FLAGS.fp16:
numerical_features = numerical_features.half()
output = model(numerical_features, categorical_features).squeeze()
loss = loss_fn(output, click)
y_true.append(click)
y_score.append(output)
loss_value = loss.item()
timer.click()
if timer.measured is not None:
metric_logger.update(loss=loss_value, step_time=timer.measured)
if step % print_freq == 0 and step > 0:
metric_logger.print(header=F"Test: [{step}/{steps_per_epoch}]")
y_true = torch.cat(y_true).cpu().numpy()
y_score = torch.cat(y_score).cpu().numpy()
auc = roc_auc_score(y_true=y_true, y_score=y_score)
model.train()
return metric_logger.loss.global_avg, auc, metric_logger.step_time.avg
def inference_benchmark(model, data_loader, num_batches=100):
model.eval()
base_device = FLAGS.base_device
latencies = []
with torch.no_grad():
for step, (numerical_features, categorical_features, click) in enumerate(data_loader):
if step > num_batches:
break
step_start_time = time()
numerical_features = numerical_features.to(base_device)
if FLAGS.fp16:
numerical_features = numerical_features.half()
categorical_features = categorical_features.to(device=base_device, dtype=torch.int64)
_ = model(numerical_features, categorical_features).squeeze()
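            # Synchronize so the measured step time includes the full GPU execution of the forward pass.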
torch.cuda.synchronize()
step_time = time() - step_start_time
if step >= FLAGS.benchmark_warmup_steps:
latencies.append(step_time)
return latencies
if __name__ == '__main__':
app.run(main)

View file

@ -0,0 +1,278 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as dist
import errno
import os
import dllogger
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def print(self, header=None):
if not header:
header = ''
print_str = header
for name, meter in self.meters.items():
print_str += F" {name}: {meter}"
print(print_str)
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target[None])
res = []
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
return res
def lr_step(optim, num_warmup_iter, current_step, base_lr, warmup_factor, decay_steps=0, decay_start_step=None):
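    # Linear warmup towards base_lr over the first num_warmup_iter steps (starting near
    # base_lr * (1 - 2**-warmup_factor)), followed by a polynomial (power-2) decay over
    # decay_steps beginning at decay_start_step; the result is floored at 1e-7.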
if decay_start_step is None:
decay_start_step = num_warmup_iter
new_lr = base_lr
if decay_start_step < num_warmup_iter:
raise ValueError('Learning rate warmup must finish before decay starts')
if current_step <= num_warmup_iter:
warmup_step = base_lr / (num_warmup_iter * (2 ** warmup_factor))
new_lr = base_lr - (num_warmup_iter - current_step) * warmup_step
steps_since_decay_start = current_step - decay_start_step
if decay_steps != 0 and steps_since_decay_start > 0:
already_decayed_steps = min(steps_since_decay_start, decay_steps)
new_lr = base_lr * ((decay_steps - already_decayed_steps) / decay_steps) ** 2
min_lr = 0.0000001
new_lr = max(min_lr, new_lr)
for param_group in optim.param_groups:
param_group['lr'] = new_lr
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def init_logging(log_path):
json_backend = dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
filename=log_path)
stdout_backend = dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
stdout_backend._metadata['best_auc'].update({'format': '0:.5f'})
stdout_backend._metadata['best_epoch'].update({'format': '0:.2f'})
stdout_backend._metadata['average_train_throughput'].update({'format': ':.2e'})
stdout_backend._metadata['average_test_throughput'].update({'format': ':.2e'})
dllogger.init(backends=[json_backend, stdout_backend])
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
setup_for_distributed(args.rank == 0)
class StepTimer():
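    # Measures wall-clock time between consecutive click() calls; `measured` stays None
    # until click() has been called at least twice.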
def __init__(self):
self._previous = None
self._new = None
self.measured = None
def click(self):
self._previous = self._new
self._new = time.time()
if self._previous is not None:
self.measured = self._new - self._previous

View file

@ -0,0 +1,726 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Gwt7z7qdmTbW"
},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# =============================================================================="
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "i4NKCp2VmTbn"
},
"source": [
"<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
"\n",
"# DLRM Triton Inference Demo"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fW0OKDzvmTbt"
},
"source": [
"## Overview\n",
"\n",
"Recomendation system (RecSys) inference involves determining an ordered list of items with which the query user will most likely interact with. For very large commercial databases with millions to hundreds of millions of items to choose from (like advertisements, apps), usually an item retrieval procedure is carried out to reduce the number of items to a more manageable quantity, e.g. a few hundreds to a few thousands. The methods include computationally-light algorithms such as approximate neighborhood search, random forest and filtering based on user preferences. From thereon, a deep learning based RecSys is invoked to re-rank the items and those with the highest scores are presented to the users. This process is well demonstrated in the Google AppStore recommendation system in Figure 1. \n",
"\n",
"![DLRM_model](recsys_inference.PNG)\n",
"\n",
"Figure 1: Googles app recommendation process. [Source](https://arxiv.org/pdf/1606.07792.pdf).\n",
"\n",
"As we can see, for each query user, the number of user-item pairs to score can be as large as a few thousands. This places an extremely heavy duty on RecSys inference server, which must handle high throughput to serve many users concurrently yet at low latency to satisfy stringent latency thresholds of online commerce engines.\n",
"\n",
"The NVIDIA Triton Inference Server [9] provides a cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or GRPC endpoint, allowing remote clients to request inferencing for any model being managed by the server. Triton automatically manages and makes use of all the available GPUs.\n",
"\n",
"We will next see how to prepare the DLRM model for inference with the Triton inference server and see how Triton is up to the task. \n",
"\n",
"### Learning objectives\n",
"\n",
"This notebook demonstrates the steps for preparing a pre-trained DLRM model for deployment and inference with the NVIDIA [Triton inference server](https://github.com/NVIDIA/triton-inference-server). \n",
"\n",
"## Content\n",
"1. [Requirements](#1)\n",
"1. [Prepare model for inference](#2)\n",
"1. [Start the Triton inference server](#3)\n",
"1. [Testing server with the performance client](#4)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "aDFrE4eqmTbv"
},
"source": [
"<a id=\"1\"></a>\n",
"## 1. Requirements\n",
"\n",
"\n",
"### 1.1 Docker container\n",
"The most convenient way to make use of the NVIDIA DLRM model is via a docker container, which provides a self-contained, isolated and re-producible environment for all experiments.\n",
"\n",
"First, clone the repository:\n",
"\n",
"```\n",
"git clone https://github.com/NVIDIA/DeepLearningExamples\n",
"cd DeepLearningExamples/PyTorch/Recommendation/DLRM\n",
"```\n",
"\n",
"To execute this notebook, first build the following inference container:\n",
"\n",
"```\n",
"docker build -t dlrm-inference . -f triton/Dockerfile\n",
"```\n",
"\n",
"Start in interactive docker session with:\n",
"\n",
"```\n",
"docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_SAVED_MODEL>:/models -v <PATH_TO_EXPORT_MODEL>:/repository <PATH_TO_PREPROCESSED_DATA>:/data dlrm-inference bash\n",
"```\n",
"where:\n",
"\n",
"- PATH_TO_SAVED_MODEL: directory containing the trained DLRM models with `.pt` extension.\n",
" \n",
"- PATH_TO_EXPORT_MODEL: directory which will contain the converted model to be used with the NVIDIA Triton inference server.\n",
"\n",
"- PATH_TO_PREPROCESSED_DATA: path to the preprocessed Criteo Terabyte dataset containing 3 binary data files: `test_data.bin`, `train_data.bin` and `val_data.bin` and a JSON `file model_size.json` totalling ~650GB.\n",
"\n",
"Within the docker interactive bash session, start Jupyter with\n",
"\n",
"```\n",
"export PYTHONPATH=/workspace/dlrm\n",
"jupyter notebook --ip 0.0.0.0 --port 8888\n",
"```\n",
"\n",
"Then open the Jupyter GUI interface on your host machine at http://localhost:8888. Within the container, this demo notebook is located at `/workspace/dlrm/notebooks`.\n",
"\n",
"### 1.2 Hardware\n",
"This notebook can be executed on any CUDA-enabled NVIDIA GPU with at least 24GB of GPU memory, although for efficient mixed precision inference, a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) is desired (Volta, Turing or newer architectures). "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "k7RLEcKhmTb0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Apr 4 00:55:05 2020 \r\n",
"+-----------------------------------------------------------------------------+\r\n",
"| NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 |\r\n",
"|-------------------------------+----------------------+----------------------+\r\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n",
"|===============================+======================+======================|\r\n",
"| 0 Tesla V100-PCIE... On | 00000000:1A:00.0 Off | 0 |\r\n",
"| N/A 30C P0 37W / 250W | 19757MiB / 32510MiB | 0% Default |\r\n",
"+-------------------------------+----------------------+----------------------+\r\n",
" \r\n",
"+-----------------------------------------------------------------------------+\r\n",
"| Processes: GPU Memory |\r\n",
"| GPU PID Type Process name Usage |\r\n",
"|=============================================================================|\r\n",
"+-----------------------------------------------------------------------------+\r\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HqSUGePjmTb9"
},
"source": [
"<a id=\"2\"></a>\n",
"## 2. Prepare model for inference\n",
"\n",
"We first convert model to a format accepted by the NVIDIA Triton inference server. Triton can accept TorchScript, ONNX amongst other formats. \n",
"\n",
"To deploy model into Triton compatible format, we provide the deployer.py [script](../triton/deployer.py).\n",
"\n",
"### TorchScript\n",
"TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.\n",
"\n",
"We provide two options to convert models to TorchScript:\n",
"- --ts-script convert to torchscript using torch.jit.script\n",
"- --ts-trace convert to torchscript using torch.jit.trace\n",
"\n",
"\n",
"In the conversion below, we assume:\n",
"\n",
"- The trained model is stored at /models/dlrm_model_fp16.pt\n",
"\n",
"- The maximum batchsize that Triton will handle is 65536.\n",
"\n",
"- The processed dataset directory is /data which contain a `model_size.json` file."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deploying model dlrm-ts-script-16 in format pytorch_libtorch\n",
"done\n"
]
}
],
"source": [
"%%bash\n",
"python ../triton/deployer.py \\\n",
"--ts-script \\\n",
"--triton-model-name dlrm-ts-script-16 \\\n",
"--triton-max-batch-size 65536 \\\n",
"--save-dir /repository \\\n",
"-- --model_checkpoint /models/dlrm_model_fp16.pt \\\n",
"--fp16 \\\n",
"--batch_size 4096 \\\n",
"--num_numerical_features 13 \\\n",
"--embedding_dim 128 \\\n",
"--top_mlp_sizes 1024 1024 512 256 1 \\\n",
"--bottom_mlp_sizes 512 256 128 \\\n",
"--interaction_op dot \\\n",
"--hash_indices \\\n",
"--dataset /data \\\n",
"--dump_perf_data ./perfdata"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "EQAIszkxmTcT"
},
"source": [
"### ONNX\n",
"\n",
"[ONNX](https://onnx.ai/) is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.\n",
"\n",
"Conversion of DLRM pre-trained PyTorch model to ONNX model can be done with:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deploying model dlrm-onnx-16 in format onnxruntime_onnx\n",
"done\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.6/site-packages/torch/onnx/symbolic_opset9.py:2044: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n",
" \"If indices include negative values, the exported graph will produce incorrect results.\")\n",
"/opt/conda/lib/python3.6/site-packages/torch/onnx/utils.py:915: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input__0\n",
" 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n",
"/opt/conda/lib/python3.6/site-packages/torch/onnx/utils.py:915: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input__1\n",
" 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n",
"/opt/conda/lib/python3.6/site-packages/torch/onnx/utils.py:915: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output__0\n",
" 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n"
]
}
],
"source": [
"%%bash\n",
"python ../triton/deployer.py \\\n",
"--onnx \\\n",
"--triton-model-name dlrm-onnx-16 \\\n",
"--triton-max-batch-size 4096 \\\n",
"--save-dir /repository \\\n",
"-- --model_checkpoint /models/dlrm_model_fp16.pt \\\n",
"--fp16 \\\n",
"--batch_size 4096 \\\n",
"--num_numerical_features 13 \\\n",
"--embedding_dim 128 \\\n",
"--top_mlp_sizes 1024 1024 512 256 1 \\\n",
"--bottom_mlp_sizes 512 256 128 \\\n",
"--interaction_op dot \\\n",
"--hash_indices \\\n",
"--dataset /data \\\n",
"--dump_perf_data ./perfdata"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "RL8d9IwzmTcV"
},
"source": [
"<a id=\"3\"></a>\n",
"## 3. Start the Triton inference server"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "o6wayGf1mTcX"
},
"source": [
"*Note: this step must be done outside the of the current docker container.*\n",
"\n",
"Open a bash window on the **host machine** and execute the following commands:\n",
"\n",
"```\n",
"docker pull nvcr.io/nvidia/tensorrtserver:20.03-py3\n",
"docker run -d --rm --gpus device=0 --ipc=host --network=host -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/repository nvcr.io/nvidia/tensorrtserver:20.03-py3 trtserver --model-store=/repository --log-verbose=1 --model-control-mode=explicit\n",
"```\n",
"\n",
"where:\n",
"\n",
"- PATH_TO_MODEL_REPOSITORY: directory on the host machine containing the converted models in section 2 above. \n",
"\n",
"Note that each DLRM model will require ~19GB of GPU memory.\n",
"\n",
"Within the `/models` directory on the inference server, the structure should look similar to the below:\n",
"\n",
"```\n",
"/models\n",
"`-- dlrm-onnx-16\n",
" |-- 1\n",
" | `-- model.onnx\n",
" | |-- bottom_mlp.0.weight\n",
" | |-- bottom_mlp.2.weight\n",
" | |-- bottom_mlp.4.weight\n",
" | |-- embeddings.0.weight\n",
" | |-- embeddings.1.weight\n",
" | |-- embeddings.10.weight\n",
" | |-- embeddings.11.weight\n",
" | |-- embeddings.12.weight\n",
" | |-- embeddings.13.weight\n",
" | |-- embeddings.14.weight\n",
" | |-- embeddings.15.weight\n",
" | |-- embeddings.17.weight\n",
" | |-- embeddings.18.weight\n",
" | |-- embeddings.19.weight\n",
" | |-- embeddings.2.weight\n",
" | |-- embeddings.20.weight\n",
" | |-- embeddings.21.weight\n",
" | |-- embeddings.22.weight\n",
" | |-- embeddings.23.weight\n",
" | |-- embeddings.24.weight\n",
" | |-- embeddings.25.weight\n",
" | |-- embeddings.3.weight\n",
" | |-- embeddings.4.weight\n",
" | |-- embeddings.6.weight\n",
" | |-- embeddings.7.weight\n",
" | |-- embeddings.8.weight\n",
" | |-- embeddings.9.weight\n",
" | |-- model.onnx\n",
" | |-- top_mlp.0.weight\n",
" | |-- top_mlp.2.weight\n",
" | |-- top_mlp.4.weight\n",
" | `-- top_mlp.6.weight\n",
" `-- config.pbtxt\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "X959LYwjmTcw"
},
"source": [
"<a id=\"4\"></a>\n",
"## 4. Testing server with the performance client\n",
"\n",
"After model deployment has completed, we can test the deployed model against the Criteo test dataset. \n",
"\n",
"Note: This requires mounting the Criteo test data to, e.g. `/data/test_data.bin`. Within the dataset directory, there must also be a `model_size.json` file."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process is terminated.\n"
]
}
],
"source": [
"%%bash\n",
"python ../triton/client.py \\\n",
"--triton-server-url localhost:8000 \\\n",
"--protocol HTTP \\\n",
"--triton-model-name dlrm-onnx-16 \\\n",
"--num_numerical_features 13 \\\n",
"--dataset_config /data/model_size.json \\\n",
"--inference_data /data/test_data.bin \\\n",
"--batch_size 4096 \\\n",
"--fp16"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Triton inference server comes with a [performance client](https://docs.nvidia.com/deeplearning/sdk/triton-inference-server-master-branch-guide/docs/optimization.html#perf-client) which is designed to stress test the server using multiple client threads.\n",
"\n",
"The perf_client generates inference requests to your model and measures the throughput and latency of those requests. To get representative results, the perf_client measures the throughput and latency over a time window, and then repeats the measurements until it gets stable values. By default the perf_client uses average latency to determine stability but you can use the --percentile flag to stabilize results based on that confidence level. For example, if --percentile=95 is used the results will be stabilized using the 95-th percentile request latency. \n",
"\n",
"### Request Concurrency\n",
"\n",
"By default perf_client measures your models latency and throughput using the lowest possible load on the model. To do this perf_client sends one inference request to the server and waits for the response. When that response is received, the perf_client immediately sends another request, and then repeats this process during the measurement windows. The number of outstanding inference requests is referred to as the request concurrency, and so by default perf_client uses a request concurrency of 1.\n",
"\n",
"Using the --concurrency-range <start>:<end>:<step> option you can have perf_client collect data for a range of request concurrency levels. Use the --help option to see complete documentation for this and other options.\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"*** Measurement Settings ***\n",
" Batch size: 4096\n",
" Measurement window: 5000 msec\n",
" Latency limit: 5000 msec\n",
" Concurrency limit: 10 concurrent requests\n",
" Using synchronous calls for inference\n",
" Stabilizing using average latency\n",
"\n",
"Request concurrency: 1\n",
" Pass [1] throughput: 67993.6 infer/sec. Avg latency: 60428 usec (std 22260 usec)\n",
" Pass [2] throughput: 61440 infer/sec. Avg latency: 66310 usec (std 21723 usec)\n",
" Pass [3] throughput: 68812.8 infer/sec. Avg latency: 59617 usec (std 22128 usec)\n",
" Client: \n",
" Request count: 84\n",
" Throughput: 68812.8 infer/sec\n",
" Avg latency: 59617 usec (standard deviation 22128 usec)\n",
" p50 latency: 71920 usec\n",
" p90 latency: 80018 usec\n",
" p95 latency: 83899 usec\n",
" p99 latency: 88054 usec\n",
" Avg gRPC time: 58773 usec (marshal 274 usec + response wait 58458 usec + unmarshal 41 usec)\n",
" Server: \n",
" Request count: 102\n",
" Avg request latency: 57208 usec (overhead 6 usec + queue 20184 usec + compute 37018 usec)\n",
"\n",
"Request concurrency: 2\n",
" Pass [1] throughput: 154010 infer/sec. Avg latency: 53139 usec (std 22418 usec)\n",
" Pass [2] throughput: 155648 infer/sec. Avg latency: 52483 usec (std 24768 usec)\n",
" Pass [3] throughput: 150733 infer/sec. Avg latency: 54271 usec (std 23803 usec)\n",
" Client: \n",
" Request count: 184\n",
" Throughput: 150733 infer/sec\n",
" Avg latency: 54271 usec (standard deviation 23803 usec)\n",
" p50 latency: 57022 usec\n",
" p90 latency: 83000 usec\n",
" p95 latency: 84782 usec\n",
" p99 latency: 88989 usec\n",
" Avg gRPC time: 55692 usec (marshal 274 usec + response wait 55374 usec + unmarshal 44 usec)\n",
" Server: \n",
" Request count: 216\n",
" Avg request latency: 53506 usec (overhead 244 usec + queue 19818 usec + compute 33444 usec)\n",
"\n",
"Request concurrency: 3\n",
" Pass [1] throughput: 189235 infer/sec. Avg latency: 64917 usec (std 21807 usec)\n",
" Pass [2] throughput: 201523 infer/sec. Avg latency: 60425 usec (std 24622 usec)\n",
" Pass [3] throughput: 203981 infer/sec. Avg latency: 60661 usec (std 24397 usec)\n",
" Client: \n",
" Request count: 249\n",
" Throughput: 203981 infer/sec\n",
" Avg latency: 60661 usec (standard deviation 24397 usec)\n",
" p50 latency: 72344 usec\n",
" p90 latency: 87765 usec\n",
" p95 latency: 91976 usec\n",
" p99 latency: 95775 usec\n",
" Avg gRPC time: 57213 usec (marshal 291 usec + response wait 56875 usec + unmarshal 47 usec)\n",
" Server: \n",
" Request count: 315\n",
" Avg request latency: 55254 usec (overhead 545 usec + queue 19408 usec + compute 35301 usec)\n",
"\n",
"Request concurrency: 4\n",
" Pass [1] throughput: 273613 infer/sec. Avg latency: 59555 usec (std 22608 usec)\n",
" Pass [2] throughput: 288358 infer/sec. Avg latency: 56895 usec (std 21886 usec)\n",
" Pass [3] throughput: 285082 infer/sec. Avg latency: 57494 usec (std 21833 usec)\n",
" Client: \n",
" Request count: 348\n",
" Throughput: 285082 infer/sec\n",
" Avg latency: 57494 usec (standard deviation 21833 usec)\n",
" p50 latency: 62012 usec\n",
" p90 latency: 83694 usec\n",
" p95 latency: 84966 usec\n",
" p99 latency: 93177 usec\n",
" Avg gRPC time: 59042 usec (marshal 317 usec + response wait 58669 usec + unmarshal 56 usec)\n",
" Server: \n",
" Request count: 404\n",
" Avg request latency: 56316 usec (overhead 569 usec + queue 19140 usec + compute 36607 usec)\n",
"\n",
"Request concurrency: 5\n",
" Pass [1] throughput: 335872 infer/sec. Avg latency: 60666 usec (std 22599 usec)\n",
" Pass [2] throughput: 308838 infer/sec. Avg latency: 65721 usec (std 22284 usec)\n",
" Pass [3] throughput: 339968 infer/sec. Avg latency: 59920 usec (std 22992 usec)\n",
" Client: \n",
" Request count: 415\n",
" Throughput: 339968 infer/sec\n",
" Avg latency: 59920 usec (standard deviation 22992 usec)\n",
" p50 latency: 67406 usec\n",
" p90 latency: 84561 usec\n",
" p95 latency: 86191 usec\n",
" p99 latency: 94862 usec\n",
" Avg gRPC time: 61127 usec (marshal 304 usec + response wait 60771 usec + unmarshal 52 usec)\n",
" Server: \n",
" Request count: 490\n",
" Avg request latency: 58036 usec (overhead 696 usec + queue 18923 usec + compute 38417 usec)\n",
"\n",
"Request concurrency: 6\n",
" Pass [1] throughput: 368640 infer/sec. Avg latency: 66037 usec (std 20247 usec)\n",
" Pass [2] throughput: 348979 infer/sec. Avg latency: 71309 usec (std 20236 usec)\n",
" Pass [3] throughput: 334234 infer/sec. Avg latency: 72704 usec (std 18491 usec)\n",
" Client: \n",
" Request count: 408\n",
" Throughput: 334234 infer/sec\n",
" Avg latency: 72704 usec (standard deviation 18491 usec)\n",
" p50 latency: 80327 usec\n",
" p90 latency: 87164 usec\n",
" p95 latency: 91824 usec\n",
" p99 latency: 95617 usec\n",
" Avg gRPC time: 71989 usec (marshal 315 usec + response wait 71617 usec + unmarshal 57 usec)\n",
" Server: \n",
" Request count: 504\n",
" Avg request latency: 68951 usec (overhead 957 usec + queue 18350 usec + compute 49644 usec)\n",
"\n",
"Request concurrency: 7\n",
" Pass [1] throughput: 395674 infer/sec. Avg latency: 72406 usec (std 18789 usec)\n",
" Pass [2] throughput: 407142 infer/sec. Avg latency: 69909 usec (std 19644 usec)\n",
" Pass [3] throughput: 355533 infer/sec. Avg latency: 81048 usec (std 12687 usec)\n",
" Client: \n",
" Request count: 434\n",
" Throughput: 355533 infer/sec\n",
" Avg latency: 81048 usec (standard deviation 12687 usec)\n",
" p50 latency: 84046 usec\n",
" p90 latency: 91642 usec\n",
" p95 latency: 94089 usec\n",
" p99 latency: 100453 usec\n",
" Avg gRPC time: 79919 usec (marshal 313 usec + response wait 79552 usec + unmarshal 54 usec)\n",
" Server: \n",
" Request count: 525\n",
" Avg request latency: 76078 usec (overhead 1042 usec + queue 17815 usec + compute 57221 usec)\n",
"\n",
"Request concurrency: 8\n",
" Pass [1] throughput: 524288 infer/sec. Avg latency: 62235 usec (std 15989 usec)\n",
" Pass [2] throughput: 524288 infer/sec. Avg latency: 62741 usec (std 15967 usec)\n",
" Pass [3] throughput: 517734 infer/sec. Avg latency: 63449 usec (std 15144 usec)\n",
" Client: \n",
" Request count: 632\n",
" Throughput: 517734 infer/sec\n",
" Avg latency: 63449 usec (standard deviation 15144 usec)\n",
" p50 latency: 68562 usec\n",
" p90 latency: 75212 usec\n",
" p95 latency: 77256 usec\n",
" p99 latency: 79685 usec\n",
" Avg gRPC time: 62683 usec (marshal 304 usec + response wait 62321 usec + unmarshal 58 usec)\n",
" Server: \n",
" Request count: 768\n",
" Avg request latency: 58942 usec (overhead 1574 usec + queue 2167 usec + compute 55201 usec)\n",
"\n",
"Request concurrency: 9\n",
" Pass [1] throughput: 376832 infer/sec. Avg latency: 98868 usec (std 34719 usec)\n",
" Pass [2] throughput: 407142 infer/sec. Avg latency: 90421 usec (std 35435 usec)\n",
" Pass [3] throughput: 346522 infer/sec. Avg latency: 106082 usec (std 33649 usec)\n",
" Client: \n",
" Request count: 423\n",
" Throughput: 346522 infer/sec\n",
" Avg latency: 106082 usec (standard deviation 33649 usec)\n",
" p50 latency: 122774 usec\n",
" p90 latency: 139616 usec\n",
" p95 latency: 143511 usec\n",
" p99 latency: 148324 usec\n",
" Avg gRPC time: 106566 usec (marshal 323 usec + response wait 106177 usec + unmarshal 66 usec)\n",
" Server: \n",
" Request count: 505\n",
" Avg request latency: 102100 usec (overhead 1046 usec + queue 43598 usec + compute 57456 usec)\n",
"\n",
"Request concurrency: 10\n",
" Pass [1] throughput: 407962 infer/sec. Avg latency: 100260 usec (std 27654 usec)\n",
" Pass [2] throughput: 403866 infer/sec. Avg latency: 101427 usec (std 34082 usec)\n",
" Pass [3] throughput: 412058 infer/sec. Avg latency: 99376 usec (std 31125 usec)\n",
" Client: \n",
" Request count: 503\n",
" Throughput: 412058 infer/sec\n",
" Avg latency: 99376 usec (standard deviation 31125 usec)\n",
" p50 latency: 100025 usec\n",
" p90 latency: 137764 usec\n",
" p95 latency: 141030 usec\n",
" p99 latency: 144104 usec\n",
" Avg gRPC time: 98137 usec (marshal 348 usec + response wait 97726 usec + unmarshal 63 usec)\n",
" Server: \n",
" Request count: 612\n",
" Avg request latency: 94377 usec (overhead 1417 usec + queue 40909 usec + compute 52051 usec)\n",
"\n",
"Inferences/Second vs. Client Average Batch Latency\n",
"Concurrency: 1, throughput: 68812.8 infer/sec, latency 59617 usec\n",
"Concurrency: 2, throughput: 150733 infer/sec, latency 54271 usec\n",
"Concurrency: 3, throughput: 203981 infer/sec, latency 60661 usec\n",
"Concurrency: 4, throughput: 285082 infer/sec, latency 57494 usec\n",
"Concurrency: 5, throughput: 339968 infer/sec, latency 59920 usec\n",
"Concurrency: 6, throughput: 334234 infer/sec, latency 72704 usec\n",
"Concurrency: 7, throughput: 355533 infer/sec, latency 81048 usec\n",
"Concurrency: 8, throughput: 517734 infer/sec, latency 63449 usec\n",
"Concurrency: 9, throughput: 346522 infer/sec, latency 106082 usec\n",
"Concurrency: 10, throughput: 412058 infer/sec, latency 99376 usec\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: Overriding max_threads specification to ensure requested concurrency range.\n"
]
}
],
"source": [
"%%bash\n",
"/workspace/install/bin/perf_client \\\n",
"--max-threads 10 \\\n",
"-m dlrm-onnx-16 \\\n",
"-x 1 \\\n",
"-p 5000 \\\n",
"-v -i gRPC \\\n",
"-u localhost:8001 \\\n",
"-b 4096 \\\n",
"-l 5000 \\\n",
"--concurrency-range 1:10 \\\n",
"--input-data ./perfdata \\\n",
"-f result.csv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualizing Latency vs. Throughput\n",
"\n",
"The perf_client provides the -f option to generate a file containing CSV output of the results.\n",
"You can import the CSV file into a spreadsheet to help visualize the latency vs inferences/second tradeoff as well as see some components of the latency. Follow these steps:\n",
"- Open this [spreadsheet](https://docs.google.com/spreadsheets/d/1IsdW78x_F-jLLG4lTV0L-rruk0VEBRL7Mnb-80RGLL4)\n",
"\n",
"- Make a copy from the File menu “Make a copy…”\n",
"\n",
"- Open the copy\n",
"\n",
"- Select the A1 cell on the “Raw Data” tab\n",
"\n",
"- From the File menu select “Import…”\n",
"\n",
"- Select “Upload” and upload the file\n",
"\n",
"- Select “Replace data at selected cell” and then select the “Import data” button\n",
"\n",
"![DLRM_model](latency_vs_throughput.PNG)\n"
]
},
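{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, you can plot the CSV locally. The short sketch below (our own illustration, not part of the repository scripts) uses pandas and matplotlib; the exact column names written by perf_client can differ between versions, so print `df.columns` first and adjust the two names used for plotting.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# result.csv is the file written by perf_client via the -f option above\n",
"df = pd.read_csv('result.csv')\n",
"print(df.columns)  # check the real header names before plotting\n",
"\n",
"# Assumed column names -- adjust to match your perf_client version\n",
"x_col, y_col = 'Concurrency', 'Inferences/Second'\n",
"df.plot(x=x_col, y=y_col, marker='o', legend=False)\n",
"plt.xlabel(x_col)\n",
"plt.ylabel(y_col)\n",
"plt.title('Throughput vs. request concurrency')\n",
"plt.show()\n",
"```"
]
},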
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "g8MxXY5GmTc8"
},
"source": [
"# Conclusion\n",
"\n",
"In this notebook, we have walked through the complete process of preparing the pretrained DLRM for inference with the Triton inference server. Then, we stress test the server with the performance client to verify inference throughput.\n",
"\n",
"## What's next\n",
"Now it's time to deploy your own DLRM model with Triton. "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "249yGNLmmTc_"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"include_colab_link": true,
"name": "TensorFlow_UNet_Industrial_Colab_train_and_inference.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

Binary file not shown.


View file

@ -0,0 +1,470 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Gwt7z7qdmTbW"
},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# =============================================================================="
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "i4NKCp2VmTbn"
},
"source": [
"<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
"\n",
"# DLRM Training and Inference Demo"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fW0OKDzvmTbt"
},
"source": [
"## Overview\n",
"\n",
"\n",
"DLRM is a deep learning based approach to recommendation introduced by Facebook. \n",
"Like other deep learning based approaches, DLRM is designed to make use of both categorical and numerical inputs which are usually present in RecSys training data. The architecture of DLRM can be understood via Figure 1. In order to handle categorical data, embedding layers map each category to a dense representation before being fed into dense multilayer perceptrons (MLP). Continuous features can be fed directly into a dense MLP. At the next level, second-order interactions of different features are computed explicitly by taking the dot product between all pairs of embedding vectors and processed dense features. Those pairwise interactions are fed into a top level MLP to compute the likelihood of interaction between users and items. \n",
"\n",
"Compared to other DL based approaches to recommendation, DLRM differs in two ways. First, DLRM computes the feature interaction explicitly while limiting the order of interaction to pairwise interactions. Second, DLRM treats each embedded feature vector (corresponding to categorical features) as a single unit, whereas other methods treat each element in the feature vector as a new unit that should yield different cross terms. These design choices help reduce computational/memory cost while maintaining competitive accuracy.\n",
"\n",
"![DLRM_model](DLRM_architecture.png)\n",
"\n",
"Figure 1. DLRM architecture.\n",
"\n",
"### Learning objectives\n",
"\n",
"This notebook demonstrates the steps for training a DLRM model. We then employ the trained model to make inference on new data.\n",
"\n",
"## Content\n",
"1. [Requirements](#1)\n",
"1. [Data download and preprocessing](#2)\n",
"1. [Training](#3)\n",
"1. [Testing trained model](#4)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "aDFrE4eqmTbv"
},
"source": [
"<a id=\"1\"></a>\n",
"## 1. Requirements\n",
"\n",
"\n",
"### 1.1 Docker container\n",
"The most convenient way to make use of the NVIDIA DLRM model is via a docker container, which provides a self-contained, isolated and re-producible environment for all experiments. Refer to the [Quick Start Guide section](../README.md) of the Readme documentation for a comprehensive guide. We briefly summarize the steps here.\n",
"\n",
"First, clone the repository:\n",
"\n",
"```\n",
"git clone https://github.com/NVIDIA/DeepLearningExamples\n",
"cd DeepLearningExamples/PyTorch/Recommendation/DLRM\n",
"```\n",
"\n",
"Next, build the DLRM container:\n",
"```\n",
"docker build . -t nvidia_dlrm_pyt\n",
"```\n",
"\n",
"Make a directory for storing DLRM data and start a docker container with:\n",
"```\n",
"mkdir -p data\n",
"docker run --runtime=nvidia -it --rm --ipc=host -v ${PWD}/data:/data nvidia_dlrm_pyt bash\n",
"```\n",
"\n",
"Within the docker interactive bash session, start Jupyter with\n",
"\n",
"```\n",
"export PYTHONPATH=/workspace/dlrm\n",
"jupyter notebook --ip 0.0.0.0 --port 8888\n",
"```\n",
"\n",
"Then open the Jupyter GUI interface on your host machine at http://localhost:8888. Within the container, the demo notebooks are located at `/workspace/dlrm/notebooks`.\n",
"\n",
"### 1.2 Hardware\n",
"This notebook can be executed on any CUDA-enabled NVIDIA GPU with at least 24GB of GPU memory, although for efficient mixed precision training, a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) is desired (Volta, Turing or newer architectures). "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "k7RLEcKhmTb0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Mar 28 06:36:59 2020 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla V100-SXM2... On | 00000000:06:00.0 Off | 0 |\n",
"| N/A 32C P0 42W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 1 Tesla V100-SXM2... On | 00000000:07:00.0 Off | 0 |\n",
"| N/A 34C P0 43W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 2 Tesla V100-SXM2... On | 00000000:0A:00.0 Off | 0 |\n",
"| N/A 34C P0 43W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 3 Tesla V100-SXM2... On | 00000000:0B:00.0 Off | 0 |\n",
"| N/A 32C P0 43W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 4 Tesla V100-SXM2... On | 00000000:85:00.0 Off | 0 |\n",
"| N/A 33C P0 43W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 5 Tesla V100-SXM2... On | 00000000:86:00.0 Off | 0 |\n",
"| N/A 35C P0 44W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 6 Tesla V100-SXM2... On | 00000000:89:00.0 Off | 0 |\n",
"| N/A 37C P0 44W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
"| 7 Tesla V100-SXM2... On | 00000000:8A:00.0 Off | 0 |\n",
"| N/A 34C P0 43W / 300W | 0MiB / 32510MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HqSUGePjmTb9"
},
"source": [
"<a id=\"2\"></a>\n",
"## 2. Data download and preprocessing\n",
"\n",
"Commercial recommendation systems are often trained on huge data sets, often in the order of terabytes, if not more. While datasets of this scale are rarely available to the public, the Criteo Terabyte click logs public [dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/) offers a rare glimpse into the scale of real enterprise data: it contains ~1.3TB of uncompressed click logs collected over the course of 24 days, that can be used to train RecSys models that predict the ads click through rate. Yet, real datasets can be potentially one or two orders of magnitude larger, as enterprises will try to leverage as much historical data as they can use, for this will generally translate into better accuracy.\n",
"\n",
"Herein, we employ the Criteo Terabyte dataset to demonstrate the efficiency of the GPU-optimized DLRM training procedure. Each record in this dataset contains 40 columns: the first is a label column that indicates whether an user clicks an ad (value 1) or not (value 0). The next 13 columns are numeric, and the last 26 are categorical columns containing obfuscated hashed values. The columns and their values are all anonymized to protect user privacy.\n",
"\n",
"\n",
"We will first download and preprocess the Criteo Terabyte dataset. Note that this will require about 1TB of disk storage.\n",
"\n",
"Notice: before downloading data, you must check out and agree with the terms and conditions of the Criteo Terabyte [dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "S2PR7weWmTcK"
},
"outputs": [],
"source": [
"! cd ../preproc && ./prepare_dataset.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "EQAIszkxmTcT"
},
"source": [
"The original Facebook DLRM code base comes with a data preprocessing utility to preprocess the data. For continuous features, the data preprocessing steps include filling in missing values with 0 and normalization (shifting the values to be >=1 and taking natural logarithm). For categorical features, the preprocessing steps include building embedding tables and transforming hashed values into integer indicators. This code runs on a single CPU thread and takes ~6.5 days to transform the whole Criteo Terabyte data set. \n",
"\n",
"We improve the data preprocessing process with Spark on CPU to make use of all CPU threads. In the docker image, we have installed spark 2.4.5, which well start a standalone Spark cluster.This results in significant improvement in data pre-processing speed, scaling approximately linearly with the number of available CPU threads. This outputs the transformed data in parquet format. We finally convert the parquet data into the binary format similar to that designed by the Facebook team specially for the Criteo dataset. \n",
"\n",
"Our preprocessing scripts are designed for the Criteo Terabyte Dataset and should work with any other dataset with the same format. The data should be split into text files. Each line of those text files should contain a single training example. An example should consist of multiple fields separated by tabulators:\n",
"- The first field is the label `1` for a positive example and `0` for negative.\n",
"- The next `N` tokens should contain the numerical features separated by tabs.\n",
"- The next `M` tokens should contain the hashed categorical features separated by tabs.\n",
"\n",
"The outcomes of the data preprocessing steps are by default stored in `/data/dlrm/binary_dataset` containing 3 binary data files: `test_data.bin`, `train_data.bin` and `val_data.bin` and a JSON `file model_size.json` totalling ~650GB.\n",
"\n",
"Tips: by defaul the preprocessing script uses the first 23 days of the Criteo Terabyte dataset for training and the last day for validation. For a quick experiment, you can download and make use of a smaller number of days by modifying the `preproc/run_spark.sh` script."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "RL8d9IwzmTcV"
},
"source": [
"<a id=\"3\"></a>\n",
"## 3. Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "o6wayGf1mTcX"
},
"source": [
"The repository provides several training recipes on 1 GPU with FP32 and automatic mixed precisions."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HapDsY4VmTce"
},
"source": [
"#### Training with FP32\n",
"Training on 1 GPU with FP32 with the `--nofp16` option."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%run ../dlrm/scripts/main \\\n",
"--mode train \\\n",
"--dataset /data/dlrm/binary_dataset \\\n",
"--nofp16 \\\n",
"--save_checkpoint_path ./dlrm_model_fp32.pt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On a V100 32GB, training takes approximately 2h56m for 1 epoch to an AUC of ~0.8. The final result should look similar to the below.\n",
"\n",
"```\n",
"Epoch:[0/1] [127600/128028] eta: 0:00:34 loss: 0.1226 step_time: 0.080038 lr: 1.1766\n",
"Epoch:[0/1] [127800/128028] eta: 0:00:18 loss: 0.1224 step_time: 0.080307 lr: 1.1480\n",
"Epoch:[0/1] [128000/128028] eta: 0:00:02 loss: 0.1221 step_time: 0.080562 lr: 1.1199\n",
"Test: [200/2721] loss: 0.1236 step_time: 0.0303\n",
"Test: [400/2721] loss: 0.1248 step_time: 0.0245\n",
"Test: [600/2721] loss: 0.1262 step_time: 0.0244\n",
"Test: [800/2721] loss: 0.1262 step_time: 0.0245\n",
"Test: [1000/2721] loss: 0.1293 step_time: 0.0245\n",
"Test: [1200/2721] loss: 0.1307 step_time: 0.0245\n",
"Test: [1400/2721] loss: 0.1281 step_time: 0.0245\n",
"Test: [1600/2721] loss: 0.1242 step_time: 0.0246\n",
"Test: [1800/2721] loss: 0.1230 step_time: 0.0245\n",
"Test: [2000/2721] loss: 0.1226 step_time: 0.0244\n",
"Test: [2200/2721] loss: 0.1239 step_time: 0.0246\n",
"Test: [2400/2721] loss: 0.1256 step_time: 0.0249\n",
"Test: [2600/2721] loss: 0.1247 step_time: 0.0248\n",
"Epoch 0 step 128027. Test loss 0.12557, auc 0.803517\n",
"Checkpoint saving took 42.90 [s]\n",
"DLL 2020-03-29 15:59:44.759627 - () best_auc : 0.80352 best_epoch : 1.00 average_train_throughput : 4.07e+05 average_test_throughput : 1.33e+06 \n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "j-aFEwb4mTcn"
},
"source": [
"#### Training with mixed-precision\n",
"Mixed precision training can be done with the `--fp16` option. Under the hood, the NVIDIA Pytorch extension library [Apex](https://github.com/NVIDIA/apex) to enable mixed precision training.\n",
"\n",
"Note: for subsequent launches of the %run magic, please restart your kernel manualy or execute the below cell to restart kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: for subsequent launches of the %run magic, \n",
"# please restart your kernel manualy or execute this cell to restart kernel.\n",
"import os\n",
"os._exit(00)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "o3AZ-CXYmTcp",
"scrolled": false
},
"outputs": [],
"source": [
"%run ../dlrm/scripts/main \\\n",
"--mode train \\\n",
"--dataset /data/dlrm/binary_dataset \\\n",
"--fp16 \\\n",
"--save_checkpoint_path ./dlrm_model_fp16.pt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On a V100 32GB, training takes approximately 1h41m for 1 epoch to an AUC of ~0.8. Thus, mixed precision training provides a speed up of ~ 1.7x.\n",
"\n",
"The final result should look similar to the below.\n",
"\n",
"```\n",
"...\n",
"Epoch:[0/1] [127800/128028] eta: 0:00:11 loss: 0.1224 step_time: 0.050719 lr: 1.1480\n",
"Epoch:[0/1] [128000/128028] eta: 0:00:01 loss: 0.1221 step_time: 0.050499 lr: 1.1199\n",
"Test: [200/2721] loss: 0.1236 step_time: 0.0271\n",
"Test: [400/2721] loss: 0.1247 step_time: 0.0278\n",
"Test: [600/2721] loss: 0.1262 step_time: 0.0275\n",
"Test: [800/2721] loss: 0.1262 step_time: 0.0278\n",
"Test: [1000/2721] loss: 0.1293 step_time: 0.0273\n",
"Test: [1200/2721] loss: 0.1306 step_time: 0.0264\n",
"Test: [1400/2721] loss: 0.1281 step_time: 0.0281\n",
"Test: [1600/2721] loss: 0.1242 step_time: 0.0273\n",
"Test: [1800/2721] loss: 0.1229 step_time: 0.0280\n",
"Test: [2000/2721] loss: 0.1226 step_time: 0.0274\n",
"Test: [2200/2721] loss: 0.1239 step_time: 0.0278\n",
"Test: [2400/2721] loss: 0.1256 step_time: 0.0289\n",
"Test: [2600/2721] loss: 0.1247 step_time: 0.0282\n",
"Epoch 0 step 128027. Test loss 0.12557, auc 0.803562\n",
"Checkpoint saving took 40.46 [s]\n",
"DLL 2020-03-28 15:15:36.290149 - () best_auc : 0.80356 best_epoch : 1.00 average_train_throughput : 6.47e+05 average_test_throughput : 1.17e+06\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "X959LYwjmTcw"
},
"source": [
"<a id=\"4\"></a>\n",
"## 4. Testing trained model\n",
"\n",
"After model training has completed, we can test the trained model against the Criteo test dataset. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: for subsequent launches of the %run magic, \n",
"# please restart your kernel manualy or execute this cell to restart kernel.\n",
"import os\n",
"os._exit(00)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%run ../dlrm/scripts/main \\\n",
"--mode test\\\n",
"--dataset /data/dlrm/binary_dataset \\\n",
"--load_checkpoint_path ./dlrm_model_fp16.pt"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "g8MxXY5GmTc8"
},
"source": [
"# Conclusion\n",
"\n",
"In this notebook, we have walked through the complete process of preparing the container and data required for training the DLRM model. We have also investigated various training options with FP32 and automatic mixed precision, trained and tested DLRM models with new test data.\n",
"\n",
"## What's next\n",
"Now it's time to try the DLRM model on your own data. Observe the performance impact of mixed precision training while comparing the final accuracy of the models trained with FP32 and mixed precision.\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "249yGNLmmTc_"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"include_colab_link": true,
"name": "TensorFlow_UNet_Industrial_Colab_train_and_inference.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View file

@@ -0,0 +1,69 @@
<!-- #region -->
# DLRM Jupyter demo notebooks
This folder contains the demo notebooks for DLRM. The most convenient way to use these notebooks is via a Docker container, which provides a self-contained, isolated and reproducible environment for all experiments. Refer to the [Quick Start Guide section](../README.md) of the README documentation for a comprehensive guide.
First, clone the repository:
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/PyTorch/Recommendation/DLRM
```
## Notebook list
### 1. Pytorch_DLRM_pyt_train_and_inference.ipynb: training and inference demo
To execute this notebook, first build the DLRM container:
```
docker build . -t nvidia_dlrm_pyt
```
Make a directory for storing DLRM data and start a Docker container with:
```
mkdir -p data
docker run --runtime=nvidia -it --rm --ipc=host -v ${PWD}/data:/data nvidia_dlrm_pyt bash
```
Within the docker interactive bash session, start Jupyter with
```
export PYTHONPATH=/workspace/dlrm
jupyter notebook --ip 0.0.0.0 --port 8888
```
Then open the Jupyter GUI interface on your host machine at http://localhost:8888. Within the container, this demo notebook is located at `/workspace/dlrm/notebooks`.
<!-- #endregion -->
### 2. DLRM_Triton_inference_demo.ipynb: inference demo with the NVIDIA Triton Inference server.
To execute this notebook, first build the following inference container:
```
docker build -t dlrm-inference . -f triton/Dockerfile
```
Start an interactive Docker session with:
```
docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_SAVED_MODEL>:/models -v <PATH_TO_EXPORT_MODEL>:/repository dlrm-inference bash
```
where:
- PATH_TO_SAVED_MODEL: directory containing the trained DLRM models.
- PATH_TO_EXPORT_MODEL: directory which will contain the converted model to be used with the NVIDIA Triton inference server.
Within the docker interactive bash session, start Jupyter with
```
export PYTHONPATH=/workspace/dlrm
jupyter notebook --ip 0.0.0.0 --port 8888
```
Then open the Jupyter GUI interface on your host machine at http://localhost:8888. Within the container, this demo notebook is located at `/workspace/dlrm/notebooks`.

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

View file

@@ -0,0 +1,90 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pandas as pd
import os
from joblib import Parallel, delayed
import glob
import argparse
import tqdm
import subprocess
def process_file(f, dst):
all_columns_sorted = [f'_c{i}' for i in range(0, 40)]
data = pd.read_parquet(f)
data = data[all_columns_sorted]
dense_columns = [f'_c{i}' for i in range(1, 14)]
data[dense_columns] = data[dense_columns].astype(np.float32)
data = data.to_records(index=False)
data = data.tobytes()
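# Each sample is serialized as its 40 columns back to back (the label, 13 float32 dense
# features and 26 categorical ids); the per-file .bin chunks are concatenated in main() below.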
dst_file = dst + '/' + f.split('/')[-1] + '.bin'
with open(dst_file, 'wb') as dst_fd:
dst_fd.write(data)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str)
parser.add_argument('--intermediate_dir', type=str)
parser.add_argument('--dst_dir', type=str)
parser.add_argument('--parallel_jobs', default=40, type=int)
args = parser.parse_args()
print('Processing train files...')
train_src_files = glob.glob(args.src_dir + '/train/*.parquet')
train_intermediate_dir = args.intermediate_dir + '/train'
os.makedirs(train_intermediate_dir, exist_ok=True)
Parallel(n_jobs=args.parallel_jobs)(delayed(process_file)(f, train_intermediate_dir) for f in tqdm.tqdm(train_src_files))
print('Train files conversion done')
print('Processing test files...')
test_src_files = glob.glob(args.src_dir + '/test/*.parquet')
test_intermediate_dir = args.intermediate_dir + '/test'
os.makedirs(test_intermediate_dir, exist_ok=True)
Parallel(n_jobs=args.parallel_jobs)(delayed(process_file)(f, test_intermediate_dir) for f in tqdm.tqdm(test_src_files))
print('Test files conversion done')
print('Processing validation files...')
valid_src_files = glob.glob(args.src_dir + '/validation/*.parquet')
valid_intermediate_dir = args.intermediate_dir + '/valid'
os.makedirs(valid_intermediate_dir, exist_ok=True)
Parallel(n_jobs=args.parallel_jobs)(delayed(process_file)(f, valid_intermediate_dir) for f in tqdm.tqdm(valid_src_files))
print('Validation files conversion done')
os.makedirs(args.dst_dir, exist_ok=True)
print('Concatenating train files')
os.system(f'cat {train_intermediate_dir}/*.bin > {args.dst_dir}/train_data.bin')
print('Concatenating test files')
os.system(f'cat {test_intermediate_dir}/*.bin > {args.dst_dir}/test_data.bin')
print('Concatenating validation files')
os.system(f'cat {valid_intermediate_dir}/*.bin > {args.dst_dir}/val_data.bin')
print('Done')
if __name__ == '__main__':
main()

View file

@@ -0,0 +1,59 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#! /bin/bash
set -e
set -x
ls -ltrash
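# The defaults below can be overridden through environment variables, e.g.:
#   download_dir=/my/criteo final_output_dir=/my/output ./prepare_dataset.sh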
download_dir=${download_dir:-'/data/dlrm/criteo'}
./verify_criteo_downloaded.sh ${download_dir}
spark_output_path=${spark_output_path:-'/data/dlrm/spark/output'}
if [ -f ${spark_output_path}/train/_SUCCESS ] \
&& [ -f ${spark_output_path}/validation/_SUCCESS ] \
&& [ -f ${spark_output_path}/test/_SUCCESS ]; then
echo "Spark preprocessing already carried done"
else
echo "Performing spark preprocessing"
./run_spark.sh ${download_dir} ${spark_output_path}
fi
conversion_intermediate_dir=${conversion_intermediate_dir:-'/data/dlrm/intermediate_binary'}
final_output_dir=${final_output_dir:-'/data/dlrm/binary_dataset'}
if [ -f ${final_output_dir}/train_data.bin ] \
&& [ -f ${final_output_dir}/val_data.bin ] \
&& [ -f ${final_output_dir}/test_data.bin ] \
&& [ -f ${final_output_dir}/model_size.json ]; then
echo "Final conversion already done"
else
echo "Performing final conversion to a custom data format"
python parquet_to_binary.py --parallel_jobs 40 --src_dir ${spark_output_path} \
--intermediate_dir ${conversion_intermediate_dir} \
--dst_dir ${final_output_dir}
cp "${spark_output_path}/model_size.json" "${final_output_dir}/model_size.json"
fi
echo "Done preprocessing the Criteo Kaggle Dataset"
echo "You can now start the training with: "
echo "python -m dlrm.scripts.main --mode train --dataset /data/dlrm/binary_dataset/ --model_config dlrm/config/default.json"

View file

@@ -0,0 +1,166 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#########################################################################
# File Name: run_spark.sh
#!/bin/bash
set -e
# the environment variables to run spark job
# should modify below environment variables
# the data path including 1TB criteo data, day_0, day_1, ...
export INPUT_PATH=${1:-'/data/dlrm/criteo'}
# the output path, use for generating the dictionary and the final dataset
# the output folder should have more than 300GB
export OUTPUT_PATH=${2:-'/data/dlrm/spark/output'}
# spark local dir should have about 3TB
# the temporary path used for spark shuffle write
export SPARK_LOCAL_DIRS='/data/dlrm/spark/tmp'
# below numbers should be adjusted according to the resource of your running environment
# set the total number of CPU cores, spark can use
export TOTAL_CORES=80
# set the number of executors
export NUM_EXECUTORS=8
# the cores for each executor, it'll be calculated
export NUM_EXECUTOR_CORES=$((${TOTAL_CORES}/${NUM_EXECUTORS}))
# unit: GB, set the max memory you want to use
export TOTAL_MEMORY=800
# unit: GB, set the memory for driver
export DRIVER_MEMORY=32
# the memory per executor
export EXECUTOR_MEMORY=$(((${TOTAL_MEMORY}-${DRIVER_MEMORY})/${NUM_EXECUTORS}))
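# e.g. with the defaults above: 80/8 = 10 cores and (800-32)/8 = 96 GB of memory per executor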
# use frequency_limit=15 or not
# by default use a frequency limit of 15
USE_FREQUENCY_LIMIT=1
OPTS=""
if [[ $USE_FREQUENCY_LIMIT == 1 ]]; then
OPTS="--frequency_limit 15"
fi
export SPARK_HOME=/opt/spark-2.4.5-bin-hadoop2.7
export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
export PATH=$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH
# we use spark standalone to run the job
export MASTER=spark://$HOSTNAME:7077
echo "Starting spark standalone"
start-master.sh
start-slave.sh $MASTER
echo "Generating the dictionary..."
spark-submit --master $MASTER \
--driver-memory "${DRIVER_MEMORY}G" \
--executor-cores $NUM_EXECUTOR_CORES \
--executor-memory "${EXECUTOR_MEMORY}G" \
--conf spark.cores.max=$TOTAL_CORES \
--conf spark.task.cpus=1 \
--conf spark.sql.files.maxPartitionBytes=1073741824 \
--conf spark.sql.shuffle.partitions=600 \
--conf spark.driver.maxResultSize=2G \
--conf spark.locality.wait=0s \
--conf spark.network.timeout=1800s \
spark_data_utils.py --mode generate_models \
$OPTS \
--input_folder $INPUT_PATH \
--days 0-23 \
--model_folder $OUTPUT_PATH/models \
--write_mode overwrite --low_mem 2>&1 | tee submit_dict_log.txt
echo "Transforming the train data from day_0 to day_22..."
spark-submit --master $MASTER \
--driver-memory "${DRIVER_MEMORY}G" \
--executor-cores $NUM_EXECUTOR_CORES \
--executor-memory "${EXECUTOR_MEMORY}G" \
--conf spark.cores.max=$TOTAL_CORES \
--conf spark.task.cpus=1 \
--conf spark.sql.files.maxPartitionBytes=1073741824 \
--conf spark.sql.shuffle.partitions=600 \
--conf spark.driver.maxResultSize=2G \
--conf spark.locality.wait=0s \
--conf spark.network.timeout=1800s \
spark_data_utils.py --mode transform \
--input_folder $INPUT_PATH \
--days 0-22 \
--output_folder $OUTPUT_PATH/train \
--model_size_file $OUTPUT_PATH/model_size.json \
--model_folder $OUTPUT_PATH/models \
--write_mode overwrite --low_mem 2>&1 | tee submit_train_log.txt
echo "Splitting the last day into 2 parts of test and validation..."
last_day=$INPUT_PATH/day_23
temp_test=$OUTPUT_PATH/temp/test
temp_validation=$OUTPUT_PATH/temp/validation
mkdir -p $temp_test $temp_validation
lines=`wc -l $last_day | awk '{print $1}'`
former=$((lines / 2))
latter=$((lines - former))
head -n $former $last_day > $temp_test/day_23
tail -n $latter $last_day > $temp_validation/day_23
echo "Transforming the test data in day_23..."
spark-submit --master $MASTER \
--driver-memory "${DRIVER_MEMORY}G" \
--executor-cores $NUM_EXECUTOR_CORES \
--executor-memory "${EXECUTOR_MEMORY}G" \
--conf spark.cores.max=$TOTAL_CORES \
--conf spark.task.cpus=1 \
--conf spark.sql.files.maxPartitionBytes=1073741824 \
--conf spark.sql.shuffle.partitions=30 \
--conf spark.driver.maxResultSize=2G \
--conf spark.locality.wait=0s \
--conf spark.network.timeout=1800s \
spark_data_utils.py --mode transform \
--input_folder $temp_test \
--days 23-23 \
--output_folder $OUTPUT_PATH/test \
--output_ordering input \
--model_folder $OUTPUT_PATH/models \
--write_mode overwrite --low_mem 2>&1 | tee submit_test_log.txt
echo "Transforming the validation data in day_23..."
spark-submit --master $MASTER \
--driver-memory "${DRIVER_MEMORY}G" \
--executor-cores $NUM_EXECUTOR_CORES \
--executor-memory "${EXECUTOR_MEMORY}G" \
--conf spark.cores.max=$TOTAL_CORES \
--conf spark.task.cpus=1 \
--conf spark.sql.files.maxPartitionBytes=1073741824 \
--conf spark.sql.shuffle.partitions=30 \
--conf spark.driver.maxResultSize=2G \
--conf spark.locality.wait=0s \
--conf spark.network.timeout=1800s \
spark_data_utils.py --mode transform \
--input_folder $temp_validation \
--days 23-23 \
--output_folder $OUTPUT_PATH/validation \
--output_ordering input \
--model_folder $OUTPUT_PATH/models \
--write_mode overwrite --low_mem 2>&1 | tee submit_validation_log.txt
rm -r $temp_test $temp_validation

View file

@@ -0,0 +1,507 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from argparse import ArgumentParser
from collections import OrderedDict
from contextlib import contextmanager
from operator import itemgetter
from time import time
from pyspark import broadcast
from pyspark.sql import Row, SparkSession, Window
from pyspark.sql.functions import *
from pyspark.sql.types import *
LABEL_COL = 0
INT_COLS = list(range(1, 14))
CAT_COLS = list(range(14, 40))
def get_column_counts_with_frequency_limit(df, frequency_limit = None):
cols = ['_c%d' % i for i in CAT_COLS]
df = (df
.select(posexplode(array(*cols)))
.withColumnRenamed('pos', 'column_id')
.withColumnRenamed('col', 'data')
.filter('data is not null')
.groupBy('column_id', 'data')
.count())
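# Illustrative examples of the --frequency_limit format: "15" applies a global threshold of
# 15 occurrences to every categorical column, while "14:50,15" uses 50 for column _c14
# and 15 for the remaining categorical columns.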
if frequency_limit:
frequency_limit = frequency_limit.split(",")
exclude = []
default_limit = None
for fl in frequency_limit:
frequency_pair = fl.split(":")
if len(frequency_pair) == 1:
default_limit = int(frequency_pair[0])
elif len(frequency_pair) == 2:
df = df.filter((col('column_id') != int(frequency_pair[0]) - CAT_COLS[0]) | (col('count') >= int(frequency_pair[1])))
exclude.append(int(frequency_pair[0]))
if default_limit:
remain = [x - CAT_COLS[0] for x in CAT_COLS if x not in exclude]
df = df.filter((~col('column_id').isin(remain)) | (col('count') >= default_limit))
# for comparing isin and separate filter
# for i in remain:
# df = df.filter((col('column_id') != i - CAT_COLS[0]) | (col('count') >= default_limit))
return df
def assign_id_with_window(df):
windowed = Window.partitionBy('column_id').orderBy(desc('count'))
return (df
.withColumn('id', row_number().over(windowed))
.withColumnRenamed('count', 'model_count'))
def assign_low_mem_partial_ids(df):
# To avoid some scaling issues with a simple window operation, we use a more complex method
# to compute the same thing, but in a more distributed spark specific way
df = df.orderBy(asc('column_id'), desc('count'))
# The monotonically_increasing_id is the partition id in the top 31 bits and the rest
# is an increasing count of the rows within that partition. So we split it into two parts,
# the partition id part_id and the count mono_id
df = df.withColumn('part_id', spark_partition_id())
return df.withColumn('mono_id', monotonically_increasing_id() - shiftLeft(col('part_id'), 33))
def assign_low_mem_final_ids(df):
# Now we can find the minimum and maximum mono_ids within a given column/partition pair
sub_model = df.groupBy('column_id', 'part_id').agg(max('mono_id').alias('top'), min('mono_id').alias('bottom'))
sub_model = sub_model.withColumn('diff', col('top') - col('bottom') + 1)
sub_model = sub_model.drop('top')
# This window function is over aggregated column/partition pair table. It will do a running sum of the rows
# within that column
windowed = Window.partitionBy('column_id').orderBy('part_id').rowsBetween(Window.unboundedPreceding, -1)
sub_model = sub_model.withColumn('running_sum', sum('diff').over(windowed)).na.fill(0, ["running_sum"])
joined = df.withColumnRenamed('column_id', 'i_column_id')
joined = joined.withColumnRenamed('part_id', 'i_part_id')
joined = joined.withColumnRenamed('count', 'model_count')
# Then we can join the original input with the pair it is a part of
joined = joined.join(sub_model, (col('i_column_id') == col('column_id')) & (col('part_id') == col('i_part_id')))
# So with all that we can subtract bottom from mono_id making it start at 0 for each partition
# and then add in the running_sum so the id is contiguous and unique for the entire column. + 1 to make it match the 1 based indexing
# for row_number
ret = joined.select(col('column_id'),
col('data'),
(col('mono_id') - col('bottom') + col('running_sum') + 1).cast(IntegerType()).alias('id'),
col('model_count'))
return ret
def get_column_models(combined_model):
for i in CAT_COLS:
model = (combined_model
.filter('column_id == %d' % (i - CAT_COLS[0]))
.drop('column_id'))
yield i, model
def col_of_rand_long():
return (rand() * (1 << 52)).cast(LongType())
def skewed_join(df, model, col_name, cutoff):
# Most versions of spark don't have a good way
# to deal with a skewed join out of the box.
# Some do and if you want to replace this with
# one of those that would be great.
# Because we have statistics about the skewness
# that we can use, we divide the model up into two parts:
# one part is the highly skewed part and we do a
# broadcast join for that part, but keep the result in
# a separate column
b_model = broadcast(model.filter(col('model_count') >= cutoff)
.withColumnRenamed('data', col_name)
.drop('model_count'))
df = (df
.join(b_model, col_name, how='left')
.withColumnRenamed('id', 'id_tmp'))
# We also need to spread the skewed data that matched
# evenly. We will use a source of randomness for this
# but use a -1 for anything that still needs to be matched
if 'ordinal' in df.columns:
rand_column = col('ordinal')
else:
rand_column = col_of_rand_long()
df = df.withColumn('join_rand',
# null values are not in the model, they are filtered out
# but can be a source of skewedness so include them in
# the even distribution
when(col('id_tmp').isNotNull() | col(col_name).isNull(), rand_column)
.otherwise(lit(-1)))
# Null out the string data that already matched to save memory
df = df.withColumn(col_name,
when(col('id_tmp').isNotNull(), None)
.otherwise(col(col_name)))
# Now do the second join, which will be a non broadcast join.
# Sadly spark is too smart for its own good and will optimize out
# joining on a column it knows will always be a constant value.
# So we have to make a convoluted version of assigning a -1 to the
# randomness column for the model itself to work around that.
nb_model = (model
.withColumn('join_rand', when(col('model_count') < cutoff, lit(-1)).otherwise(lit(-2)))
.filter(col('model_count') < cutoff)
.withColumnRenamed('data', col_name)
.drop('model_count'))
df = (df
.join(nb_model, ['join_rand', col_name], how='left')
.drop(col_name, 'join_rand')
# Pick either join result as an answer
.withColumn(col_name, coalesce(col('id'), col('id_tmp')))
.drop('id', 'id_tmp'))
return df
def apply_models(df, models, broadcast_model = False, skew_broadcast_pct = 1.0):
# sort the models so broadcast joins come first. This is
# so we reduce the amount of shuffle data sooner than later
# If we parsed the string hex values to ints early on this would
# not make a difference.
models = sorted(models, key=itemgetter(3), reverse=True)
for i, model, original_rows, would_broadcast in models:
col_name = '_c%d' % i
if not (would_broadcast or broadcast_model):
# The data is highly skewed so we need to offset that
cutoff = int(original_rows * skew_broadcast_pct/100.0)
df = skewed_join(df, model, col_name, cutoff)
else:
# broadcast joins can handle skewed data so no need to
# do anything special
model = (model.drop('model_count')
.withColumnRenamed('data', col_name))
model = broadcast(model) if broadcast_model else model
df = (df
.join(model, col_name, how='left')
.drop(col_name)
.withColumnRenamed('id', col_name))
return df.fillna(0, ['_c%d' % i for i in CAT_COLS])
def transform_log(df, transform_log = False):
cols = ['_c%d' % i for i in INT_COLS]
if transform_log:
for col_name in cols:
df = df.withColumn(col_name, log(df[col_name] + 3))
return df.fillna(0, cols)
def would_broadcast(spark, str_path):
sc = spark.sparkContext
config = sc._jsc.hadoopConfiguration()
path = sc._jvm.org.apache.hadoop.fs.Path(str_path)
fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config)
stat = fs.listFiles(path, True)
sum = 0
while stat.hasNext():
sum = sum + stat.next().getLen()
sql_conf = sc._jvm.org.apache.spark.sql.internal.SQLConf()
cutoff = sql_conf.autoBroadcastJoinThreshold() * sql_conf.fileCompressionFactor()
return sum <= cutoff
def delete_data_source(spark, path):
sc = spark.sparkContext
config = sc._jsc.hadoopConfiguration()
path = sc._jvm.org.apache.hadoop.fs.Path(path)
sc._jvm.org.apache.hadoop.fs.FileSystem.get(config).delete(path, True)
def load_raw(spark, folder, day_range):
label_fields = [StructField('_c%d' % LABEL_COL, IntegerType())]
int_fields = [StructField('_c%d' % i, IntegerType()) for i in INT_COLS]
str_fields = [StructField('_c%d' % i, StringType()) for i in CAT_COLS]
schema = StructType(label_fields + int_fields + str_fields)
paths = [os.path.join(folder, 'day_%d' % i) for i in day_range]
return (spark
.read
.schema(schema)
.option('sep', '\t')
.csv(paths))
def rand_ordinal(df):
# create a random long from the double precision float.
# The fraction part of a double is 52 bits, so we try to capture as much
# of that as possible
return df.withColumn('ordinal', col_of_rand_long())
def day_from_ordinal(df, num_days):
return df.withColumn('day', (col('ordinal') % num_days).cast(IntegerType()))
def day_from_input_file(df):
return df.withColumn('day', substring_index(input_file_name(), '_', -1).cast(IntegerType()))
def psudo_sort_by_day_plus(spark, df, num_days):
# Sort is very expensive because it needs to calculate the partitions
# which in our case may involve rereading all of the data. In some cases
# we can avoid this by repartitioning the data and sorting within a single partition
shuffle_parts = int(spark.conf.get('spark.sql.shuffle.partitions'))
extra_parts = int(shuffle_parts/num_days)
if extra_parts <= 0:
df = df.repartition('day')
else:
#We want to spread out the computation to about the same amount as shuffle_parts
divided = (col('ordinal') / num_days).cast(LongType())
extra_ident = divided % extra_parts
df = df.repartition(col('day'), extra_ident)
return df.sortWithinPartitions('day', 'ordinal')
def load_combined_model(spark, model_folder):
path = os.path.join(model_folder, 'combined.parquet')
return spark.read.parquet(path)
def save_combined_model(df, model_folder, mode=None):
path = os.path.join(model_folder, 'combined.parquet')
df.write.parquet(path, mode=mode)
def delete_combined_model(spark, model_folder):
path = os.path.join(model_folder, 'combined.parquet')
delete_data_source(spark, path)
def load_low_mem_partial_ids(spark, model_folder):
path = os.path.join(model_folder, 'partial_ids.parquet')
return spark.read.parquet(path)
def save_low_mem_partial_ids(df, model_folder, mode=None):
path = os.path.join(model_folder, 'partial_ids.parquet')
df.write.parquet(path, mode=mode)
def delete_low_mem_partial_ids(spark, model_folder):
path = os.path.join(model_folder, 'partial_ids.parquet')
delete_data_source(spark, path)
def load_column_models(spark, model_folder, count_required):
for i in CAT_COLS:
path = os.path.join(model_folder, '%d.parquet' % i)
df = spark.read.parquet(path)
if count_required:
values = df.agg(sum('model_count').alias('sum'), count('*').alias('size')).collect()
else:
values = df.agg(sum('model_count').alias('sum')).collect()
yield i, df, values[0], would_broadcast(spark, path)
def save_column_models(column_models, model_folder, mode=None):
for i, model in column_models:
path = os.path.join(model_folder, '%d.parquet' % i)
model.write.parquet(path, mode=mode)
def save_model_size(model_size, path, write_mode):
if os.path.exists(path) and write_mode == 'errorifexists':
print('Error: model size file %s exists' % path)
sys.exit(1)
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
with open(path, 'w') as fp:
json.dump(model_size, fp, indent=4)
_benchmark = {}
@contextmanager
def _timed(step):
start = time()
yield
end = time()
_benchmark[step] = end - start
def _parse_args():
parser = ArgumentParser()
parser.add_argument(
'--mode',
required=True,
choices=['generate_models', 'transform'])
parser.add_argument('--days', required=True)
parser.add_argument('--input_folder', required=True)
parser.add_argument('--output_folder')
parser.add_argument('--model_size_file')
parser.add_argument('--model_folder', required=True)
parser.add_argument(
'--write_mode',
choices=['overwrite', 'errorifexists'],
default='errorifexists')
parser.add_argument('--frequency_limit')
parser.add_argument('--no_numeric_log_col', action='store_true')
#Support for running in a lower memory environment
parser.add_argument('--low_mem', action='store_true')
parser.add_argument(
'--output_ordering',
choices=['total_random', 'day_random', 'any', 'input'],
default='total_random')
parser.add_argument(
'--output_partitioning',
choices=['day', 'none'],
default='none')
parser.add_argument('--dict_build_shuffle_parallel_per_day', type=int, default=2)
parser.add_argument('--apply_shuffle_parallel_per_day', type=int, default=25)
parser.add_argument('--skew_broadcast_pct', type=float, default=1.0)
parser.add_argument('--debug_mode', action='store_true')
args = parser.parse_args()
start, end = args.days.split('-')
args.day_range = list(range(int(start), int(end) + 1))
args.days = len(args.day_range)
return args
def _main():
args = _parse_args()
spark = SparkSession.builder.getOrCreate()
df = load_raw(spark, args.input_folder, args.day_range)
if args.mode == 'generate_models':
spark.conf.set('spark.sql.shuffle.partitions', args.days * args.dict_build_shuffle_parallel_per_day)
with _timed('generate models'):
col_counts = get_column_counts_with_frequency_limit(df, args.frequency_limit)
if args.low_mem:
# in low memory mode we have to save an intermediate result
# because if we try to do it in one query spark ends up assigning the
# partial ids in two different locations that are not guaranteed to line up
# this prevents that from happening by assigning the partial ids
# and then writing them out.
save_low_mem_partial_ids(
assign_low_mem_partial_ids(col_counts),
args.model_folder,
args.write_mode)
save_combined_model(
assign_low_mem_final_ids(load_low_mem_partial_ids(spark, args.model_folder)),
args.model_folder,
args.write_mode)
if not args.debug_mode:
delete_low_mem_partial_ids(spark, args.model_folder)
else:
save_combined_model(
assign_id_with_window(col_counts),
args.model_folder,
args.write_mode)
save_column_models(
get_column_models(load_combined_model(spark, args.model_folder)),
args.model_folder,
args.write_mode)
if not args.debug_mode:
delete_combined_model(spark, args.model_folder)
if args.mode == 'transform':
spark.conf.set('spark.sql.shuffle.partitions', args.days * args.apply_shuffle_parallel_per_day)
with _timed('transform'):
if args.output_ordering == 'total_random':
df = rand_ordinal(df)
if args.output_partitioning == 'day':
df = day_from_ordinal(df, args.days)
elif args.output_ordering == 'day_random':
df = rand_ordinal(df)
df = day_from_input_file(df)
elif args.output_ordering == 'input':
df = df.withColumn('ordinal', monotonically_increasing_id())
if args.output_partitioning == 'day':
df = day_from_input_file(df)
else: # any ordering
if args.output_partitioning == 'day':
df = day_from_input_file(df)
models = list(load_column_models(spark, args.model_folder, bool(args.model_size_file)))
if args.model_size_file:
save_model_size(
OrderedDict(('_c%d' % i, agg.size) for i, _, agg, _ in models),
args.model_size_file,
args.write_mode)
models = [(i, df, agg.sum, flag) for i, df, agg, flag in models]
df = apply_models(
df,
models,
not args.low_mem,
args.skew_broadcast_pct)
df = transform_log(df, not args.no_numeric_log_col)
if args.output_partitioning == 'day':
partitionBy = 'day'
else:
partitionBy = None
if args.output_ordering == 'total_random':
if args.output_partitioning == 'day':
df = psudo_sort_by_day_plus(spark, df, args.days)
else: # none
# Don't do a full sort it is expensive. Order is random so
# just make it random
df = df.repartition('ordinal').sortWithinPartitions('ordinal')
df = df.drop('ordinal')
elif args.output_ordering == 'day_random':
df = psudo_sort_by_day_plus(spark, df, args.days)
df = df.drop('ordinal')
if args.output_partitioning != 'day':
df = df.drop('day')
elif args.output_ordering == 'input':
if args.low_mem:
# This is the slowest option. We totally messed up the order so we have to put
# it back in the correct order
df = df.orderBy('ordinal')
else:
# Applying the dictionary happened within a single task so we are already really
# close to the correct order, just need to sort within the partition
df = df.sortWithinPartitions('ordinal')
df = df.drop('ordinal')
if args.output_partitioning != 'day':
df = df.drop('day')
# else: any ordering so do nothing the ordering does not matter
df.write.parquet(
args.output_folder,
mode=args.write_mode,
partitionBy=partitionBy)
print('=' * 100)
print(_benchmark)
if __name__ == '__main__':
_main()

View file

@@ -0,0 +1,34 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#! /bin/bash
set -e
set -x
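# Usage: ./verify_criteo_downloaded.sh [download_dir]   (defaults to /data/dlrm/criteo)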
download_dir=${1:-'/data/dlrm/criteo'}
cd ${download_dir}
for i in $(seq 0 23); do
filename=day_${i}
if [ -f $filename ]; then
echo "$filename exists, OK"
else
echo "$filename does not exist. Please follow the instructions at: http://labs.criteo.com/2013/12/download-terabyte-click-logs/ to download it"
exit 1
fi
done
cd -
echo "Criteo data verified"

View file

@@ -0,0 +1,4 @@
-e git://github.com/NVIDIA/dllogger#egg=dllogger
absl-py>=0.7.0
numpy
pyarrow

View file

@@ -0,0 +1,31 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
abspath = os.path.dirname(os.path.realpath(__file__))
print(find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]))
setup(name="dlrm",
package_dir={'dlrm': 'dlrm'},
version="1.0.0",
description="Reimplementation of Facebook's DLRM",
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
zip_safe=False,
cmdclass={"build_ext": BuildExtension})

View file

@@ -0,0 +1,31 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk as trt
FROM ${FROM_IMAGE_NAME}
ADD requirements.txt .
RUN pip install -r requirements.txt
RUN pip install onnxruntime
COPY --from=trt /workspace/install /workspace/install/
ENV LD_LIBRARY_PATH /workspace/install/lib:${LD_LIBRARY_PATH}
RUN ls /workspace/install/python
RUN pip install /workspace/install/python/tensorrtserver-1.12.0-py3-none-linux_x86_64.whl
ENV PYTHONPATH /workspace/dlrm
WORKDIR /workspace/dlrm
COPY . .

View file

@@ -0,0 +1,278 @@
# Deploying the DLRM model using Triton Inference Server
The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server.
This folder contains deployment instructions, an example client application for running inference on the
Triton Inference Server, and a detailed performance analysis.
## Table Of Contents
- [Running Triton Inference Server and client](#running-triton-inference-server-and-client)
- [Latency vs Throughput](#throughputlatency-results)
- [Dynamic batching support](#dynamic-batching-support)
## Running Triton Inference Server and client
The very first step of deployment is to acquire a trained checkpoint and the model configuration for this
checkpoint. Default model configurations are stored inside the `dlrm/config` directory.
### Inference container
Every command below is called from a special inference container. To build that container, go to the main
repository folder and call
`docker build -t dlrm-inference . -f triton/Dockerfile`
This command will download dependencies and build the inference container. Then run a shell inside the
container:
`docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_MODEL_REPOSITORY>:/repository dlrm-inference bash`
Here, `device=0,1,2,3` would select the GPUs indexed by ordinals `0`, `1`, `2` and `3`, respectively; the server will see only these GPUs. If you write `device=all`, the server will see all available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the
deployed models were stored.
### Deploying the model
To deploy the model into a Triton-compatible format, the `deployer.py` script can be used. This script is
meant to be run from inside the deployment docker container.
```
usage: deployer.py [-h] (--ts-script | --ts-trace | --onnx) [--triton-no-cuda]
[--triton-model-name TRITON_MODEL_NAME]
[--triton-model-version TRITON_MODEL_VERSION]
[--triton-max-batch-size TRITON_MAX_BATCH_SIZE]
[--triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY]
[--triton-engine-count TRITON_ENGINE_COUNT]
[--save-dir SAVE_DIR]
...
optional arguments:
-h, --help show this help message and exit
--ts-script convert to torchscript using torch.jit.script
--ts-trace convert to torchscript using torch.jit.trace
--onnx convert to onnx using torch.onnx.export
triton related flags:
--triton-no-cuda Use the CPU for tracing.
--triton-model-name TRITON_MODEL_NAME
exports to appropriate directory structure for triton
--triton-model-version TRITON_MODEL_VERSION
exports to appropriate directory structure for triton
--triton-max-batch-size TRITON_MAX_BATCH_SIZE
Specifies the 'max_batch_size' in the triton model
config. See the triton documentation for more info.
--triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY
Determines the dynamic_batching queue delay in
milliseconds(ms) for the triton model config. Use '0'
or '-1' to specify static batching. See the triton
documentation for more info.
--triton-engine-count TRITON_ENGINE_COUNT
Specifies the 'instance_group' count value in the
triton model config. See the triton documentation for
more info.
--save-dir SAVE_DIR Saved model directory
other flags:
model_arguments arguments that will be ignored by deployer lib and
will be forwarded to your deployer script
```
The following model-specific arguments have to be specified for model deployment:
```
--num_numerical_features NUM_FEATURES
Number of numerical features at network input.
--embedding_dim EMBEDDING_DIM
Embedding dimensionality.
--top_mlp_sizes TOP_MLP_SIZES [TOP_MLP_SIZES ...]
Units in layers of top MLP (default: 1024 1024 512 256 1).
--bottom_mlp_sizes BOTTOM_MLP_SIZES [BOTTOM_MLP_SIZES ...]
Units in layers of bottom MLP (default: 512 256 128).
--interaction_op {cat,dot}
Interaction operator to use.
--self_interaction
Enables self interaction.
--hash_indices
Hash indices for categorical features.
--dataset DATASET
Path to dataset directory containing model_size.json file
describing input sizes for each embedding layer.
--batch_size BATCH_SIZE
Internal dataloader batch size, usually it is the same as batch size
specified in --triton-max-batch_size flag.
--fp16
Set a model for fp16 deployment.
--dump_perf_data DIRECTORY_NAME
Dump binary performance data that can be loaded by the perf client.
--model_checkpoint MODEL_CHECKPOINT
Checkpoint file with trained model that is going to be deployed.
--cpu Export cpu model instead of gpu.
```
For example, to deploy the model in ONNX format, using half precision and a max batch size of 4096, under the name
`dlrm-onnx-16`, execute:
`python triton/deployer.py --onnx --triton-model-name dlrm-onnx-16 --triton-max-batch-size 4096 --save-dir /repository -- --model_checkpoint /results/checkpoint --fp16 --batch_size 4096 --num_numerical_features 13 --embedding_dim 128 --top_mlp_sizes 1024 1024 512 256 1 --bottom_mlp_sizes 512 256 128 --interaction_op dot --hash_indices --dataset /data`
Where `--model_checkpoint` points to a checkpoint of a model trained with the same configuration as used during export, and the dataset (or at least the dataset configuration)
is mounted under `/data`.
### Running the Triton server
**NOTE: This step is executed outside inference container**
1. `docker pull nvcr.io/nvidia/tritonserver:20.03-py3`
2. `docker run -d --rm --gpus device=0 --ipc=host --network=host [--cpuset-cpus=0-15] -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/models nvcr.io/nvidia/tritonserver:20.03-py3 trtserver --model-store=/models --log-verbose=1 --model-control-mode=explicit`
Here, `device=0,1,2,3` would select the GPUs indexed by ordinals `0`, `1`, `2` and `3`, respectively; the server will see only these GPUs. If you write `device=all`, the server will see all available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the
deployed models were stored. The additional `--model-control-mode` option allows models to be loaded and
unloaded manually. This is especially useful when dealing with numerous large models like DLRM.
For models exported to the ONNX format and hosted inside ONNX Runtime, it might be necessary to limit the visible CPUs to fully utilize GPU acceleration. Use the `--cpuset-cpus` docker option for that.
### Running client
The example client `client.py` allows you to check model performance against synthetic or real validation
data. The client connects to the Triton server and performs inference.
```
usage: client.py [-h] --triton-server-url TRITON_SERVER_URL
--triton-model-name TRITON_MODEL_NAME
[--triton-model-version TRITON_MODEL_VERSION]
[--protocol PROTOCOL] [-v] [-H HTTP_HEADER]
[--num_numerical_features NUM_NUMERICAL_FEATURES]
--dataset_config DATASET_CONFIG
[--inference_data INFERENCE_DATA] [--batch_size BATCH_SIZE]
[--fp16]
optional arguments:
-h, --help show this help message and exit
--triton-server-url TRITON_SERVER_URL
URL address of the Triton server (with port)
--triton-model-name TRITON_MODEL_NAME
Triton deployed model name
--triton-model-version TRITON_MODEL_VERSION
Triton model version
--protocol PROTOCOL Communication protocol (HTTP/GRPC)
-v, --verbose Verbose mode.
-H HTTP_HEADER HTTP headers to add to inference server requests.
Format is -H"Header:Value".
--num_numerical_features NUM_NUMERICAL_FEATURES
Number of numerical features as an input.
--dataset_config DATASET_CONFIG
Configuration file describing categorical features
--inference_data INFERENCE_DATA
Path to file with inference data.
--batch_size BATCH_SIZE
Inference request batch size
--fp16 Use 16bit for numerical input
```
To run inference on the model exported in the previous steps, using data located under
`/data/test_data.bin`, execute:
`python triton/client.py --triton-server-url localhost:8000 --protocol HTTP --triton-model-name dlrm-onnx-16 --num_numerical_features 13 --dataset_config /data/model_size.json --inference_data /data/test_data.bin --batch_size 4096 --fp16`
or
`python triton/client.py --triton-server-url localhost:8001 --protocol GRPC --triton-model-name dlrm-onnx-16 --num_numerical_features 13 --dataset_config /data/model_size.json --inference_data /data/test_data.bin --batch_size 4096 --fp16`
### Gathering performance data
Performance data can be gathered using the `perf_client` tool. To use this tool, binary performance data needs
to be dumped during deployment. To do that, use the `--dump_perf_data` option of the deployer:
`python triton/deployer.py --onnx --triton-model-name dlrm-onnx-16 --triton-max-batch-size 4096 --save-dir /repository -- --model_checkpoint /results/checkpoint --fp16 --batch_size 4096 --num_numerical_features 13 --embedding_dim 128 --top_mlp_sizes 1024 1024 512 256 1 --bottom_mlp_sizes 512 256 128 --interaction_op dot --hash_indices --dataset /data --dump_perf_data /location/for/perfdata`
When the perf data is dumped, `perf_client` can be used with the following command:
`/workspace/install/bin/perf_client --max-threads 10 -m dlrm-onnx-16 -x 1 -p 5000 -v -i gRPC -u localhost:8001 -b 4096 -l 5000 --concurrency-range 1 --input-data /location/for/perfdata -f result.csv`
For more information about `perf_client` please refer to [official documentation](https://docs.nvidia.com/deeplearning/sdk/triton-inference-server-master-branch-guide/docs/optimization.html#perf-client).
## Throughput/Latency results
Throughput is measured in recommendations/second, and latency in milliseconds.
**ONNX FP16 inference (V100-32G)**
| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
|----------------|----------------|-----------------|-----------------|-----------------|
| 1 | 432.4 rec/s | 2.31 ms | 2.42 ms | 2.51 ms |
| 8 | 3214.4 rec/s | 2.48 ms | 2.64 ms | 2.72 ms |
| 64 | 26924.8 rec/s | 2.37 ms | 2.50 ms | 2.57 ms |
| 512 | 190413 rec/s | 2.68 ms | 2.85 ms | 2.94 ms |
| 4096 | 891290 rec/s | 4.58 ms | 4.82 ms | 4.96 ms |
| 32768 | 1218970 rec/s | 26.85 ms | 27.43 ms | 28.81 ms |
| 65536 | 1245180 rec/s | 52.55 ms | 53.46 ms | 55.83 ms |
| 131072 | 1140330 rec/s | 115.24 ms | 117.56 ms | 120.32 ms |
**TorchScript FP16 inference (V100-32G)**
| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
|----------------|----------------|-----------------|-----------------|-----------------|
| 1 | 399.6 rec/s | 2.50 ms | 2.56 ms | 2.70 ms |
| 8 | 3563.2 rec/s | 2.24 ms | 2.29 ms | 2.42 ms |
| 64 | 28288.2 rec/s | 2.26 ms | 2.33 ms | 2.41 ms |
| 512 | 220774 rec/s | 2.31 ms | 2.38 ms | 2.44 ms |
| 4096 | 1104280 rec/s | 3.70 ms | 3.78 ms | 3.86 ms |
| 32768 | 1428680 rec/s | 22.97 ms | 23.29 ms | 24.05 ms |
| 65536 | 1402470 rec/s | 46.80 ms | 48.12 ms | 52.88 ms |
| 131072 | 1546650 rec/s | 85.27 ms | 86.17 ms | 87.05 ms |
**TorchScript FP32 inference (V100-32G)**
| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
|----------------|----------------|-----------------|-----------------|-----------------|
| 1 | 333.7 rec/s | 2.99 ms | 3.17 ms | 3.32 ms |
| 8 | 3092.8 rec/s | 2.58 ms | 2.79 ms | 2.91 ms |
| 64 | 24435.2 rec/s | 2.61 ms | 2.78 ms | 2.89 ms |
| 512 | 169216 rec/s | 3.02 ms | 3.14 ms | 3.19 ms |
| 4096 | 718438 rec/s | 5.69 ms | 5.93 ms | 6.08 ms |
| 32768 | 842138 rec/s | 38.96 ms | 39.68 ms | 41.02 ms |
| 65536 | 892138 rec/s | 73.53 ms | 74.56 ms | 74.99 ms |
| 131072 | 904397 rec/s | 146.11 ms | 149.88 ms | 151.43 ms |
**ONNX FP32 inference CPU (2x E5-2698 v4 @ 2.20GHz)**
| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
|----------------|----------------|-----------------|-----------------|-----------------|
| 1 | 402.5 rec/s | 2.48 ms | 2.34 ms | 3.16 ms |
| 8 | 2316 rec/s | 3.39 ms | 2.89 ms | 6.93 ms |
| 64 | 9248 rec/s | 6.91 ms | 6.73 ms | 13.14 ms |
| 512 | 14643.3 rec/s | 35.00 ms | 42.77 ms | 69.24 ms |
| 4096 | 13926.4 rec/s | 291.28 ms | 321.90 ms | 490.06 ms |
| 32768 | 13107.2 rec/s | 2387.24 ms | 2395.80 ms | 2395.80 ms |
| 65536 | 14417.9 rec/s | 5008.26 ms | 5311.47 ms | 5311.47 ms |
| 131072 | 13107.2 rec/s | 10033.19 ms | 10416.43 ms | 10416.43 ms |
**TorchScript FP32 inference CPU (2x E5-2698 v4 @ 2.20GHz)**
| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
|----------------|----------------|-----------------|-----------------|-----------------|
| 1 | 116.3 rec/s | 8.60 ms | 9.83 ms | 14.60 ms |
| 8 | 3723.2 rec/s | 2.14 ms | 2.55 ms | 2.78 ms |
| 64 | 3014.4 rec/s | 21.22 ms | 31.34 ms | 41.28 ms |
| 512 | 6451.2 rec/s | 79.69 ms | 106.00 ms | 296.39 ms |
| 4096 | 41984 rec/s | 97.71 ms | 118.70 ms | 123.37 ms |
| 32768 | 79735.5 rec/s | 407.98 ms | 426.64 ms | 430.66 ms |
| 65536 | 79021.8 rec/s | 852.90 ms | 902.39 ms | 911.46 ms |
| 131072 | 81264.6 rec/s | 1601.28 ms | 1694.64 ms | 1711.57 ms |
![Latency vs Throughput](./img/lat_vs_thr.png)
The plot above shows that the GPU is saturated at batch size 4096. However, running inference with larger batches
might be faster than issuing several smaller inference requests. Therefore, we choose 65536 as the optimal batch size.
## Dynamic batching support
The Triton server has a built-in dynamic batching mechanism that can be enabled. When it is enabled, the server creates
inference batches from the received requests. Since the output of the model is a single probability, the batch size of a
single request may be large; here it is assumed to be 4096. With dynamic batching enabled, the server concatenates requests of this size into
an inference batch. The upper bound on the inference batch size is set to 65536. All of these parameters are configurable.
Performance results on a single V100-32G (ONNX FP16 model) for various numbers of simultaneous requests are shown in the figure below.
![Dynamic batching](./img/dyn_batch_concurrency.png)
The plot above shows that with a 20 ms upper bound on latency, a single GPU can handle up to 8 concurrent requests.
This leads to a total throughput of 1,776,030 recommendations/sec, which means 35,520 recommendations within 20 ms on a single GPU.

View file

@@ -0,0 +1,133 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import torch
from dlrm.data import data_loader
from dlrm.data.synthetic_dataset import SyntheticDataset
from tqdm import tqdm
from tensorrtserver.api import *
from sklearn.metrics import roc_auc_score
from functools import partial
def get_data_loader(batch_size, *, data_file, model_config):
with open(model_config.dataset_config) as f:
categorical_sizes = list(json.load(f).values())
if data_file:
data = data_loader.CriteoBinDataset(data_file=data_file,
batch_size=batch_size, subset=None,
numerical_features=model_config.num_numerical_features,
categorical_features=len(categorical_sizes),
online_shuffle=False)
else:
data = SyntheticDataset(num_entries=batch_size * 1024, batch_size=batch_size,
dense_features=model_config.num_numerical_features,
categorical_feature_sizes=categorical_sizes,
device="cpu")
return torch.utils.data.DataLoader(data,
batch_size=None,
num_workers=0,
pin_memory=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--triton-server-url", type=str, required=True,
help="URL adress of trtion server (with port)")
parser.add_argument("--triton-model-name", type=str, required=True,
help="Triton deployed model name")
parser.add_argument("--triton-model-version", type=int, default=-1,
help="Triton model version")
parser.add_argument("--protocol", type=str, default="HTTP",
help="Communication protocol (HTTP/GRPC)")
parser.add_argument("-v", "--verbose", action="store_true", default=False,
help="Verbose mode.")
parser.add_argument('-H', dest='http_headers', metavar="HTTP_HEADER",
required=False, action='append',
help='HTTP headers to add to inference server requests. ' +
'Format is -H"Header:Value".')
parser.add_argument("--num_numerical_features", type=int, default=13)
parser.add_argument("--dataset_config", type=str, required=True)
parser.add_argument("--inference_data", type=str,
help="Path to file with inference data.")
parser.add_argument("--batch_size", type=int, default=1,
help="Inference request batch size")
parser.add_argument("--fp16", action="store_true", default=False,
help="Use 16bit for numerical input")
FLAGS = parser.parse_args()
FLAGS.protocol = ProtocolType.from_str(FLAGS.protocol)
# Create a health context, get the ready and live state of server.
health_ctx = ServerHealthContext(FLAGS.triton_server_url, FLAGS.protocol,
http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
print("Health for model {}".format(FLAGS.triton_model_name))
print("Live: {}".format(health_ctx.is_live()))
print("Ready: {}".format(health_ctx.is_ready()))
with ModelControlContext(FLAGS.triton_server_url, FLAGS.protocol) as ctx:
ctx.load(FLAGS.triton_model_name)
# Create a status context and get server status
status_ctx = ServerStatusContext(FLAGS.triton_server_url, FLAGS.protocol, FLAGS.triton_model_name,
http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
print("Status for model {}".format(FLAGS.triton_model_name))
print(status_ctx.get_server_status())
# Create the inference context for the model.
infer_ctx = InferContext(FLAGS.triton_server_url, FLAGS.protocol, FLAGS.triton_model_name,
FLAGS.triton_model_version,
http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
dataloader = get_data_loader(FLAGS.batch_size,
data_file=FLAGS.inference_data,
model_config=FLAGS)
results = []
tgt_list = []
for num, cat, target in tqdm(dataloader):
num = num.cpu().numpy()
if FLAGS.fp16:
num = num.astype(np.float16)
cat = cat.long().cpu().numpy()
input_dict = {"input__0": tuple(num[i] for i in range(len(num))),
"input__1": tuple(cat[i] for i in range(len(cat)))}
output_keys = ["output__0"]
output_dict = {x: InferContext.ResultFormat.RAW for x in output_keys}
result = infer_ctx.run(input_dict, output_dict, len(num))
results.append(result["output__0"])
tgt_list.append(target.cpu().numpy())
results = np.concatenate(results).squeeze()
tgt_list = np.concatenate(tgt_list)
score = roc_auc_score(tgt_list, results)
print(F"Model score: {score}")
with ModelControlContext(FLAGS.triton_server_url, FLAGS.protocol) as ctx:
ctx.unload(FLAGS.triton_model_name)

View file

@@ -0,0 +1,127 @@
#!/usr/bin/python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import argparse
import deployer_lib
import json
#
import sys
sys.path.append('../')
from dlrm.model import Dlrm
from dlrm.data.synthetic_dataset import SyntheticDataset
def get_model_args(model_args):
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--fp16", action="store_true", default=False)
parser.add_argument("--dump_perf_data", type=str, default=None)
parser.add_argument("--model_checkpoint", type=str, default=None)
parser.add_argument("--num_numerical_features", type=int, default=13)
parser.add_argument("--embedding_dim", type=int, default=128)
parser.add_argument("--top_mlp_sizes", type=int, nargs="+",
default=[1024, 1024, 512, 256, 1])
parser.add_argument("--bottom_mlp_sizes", type=int, nargs="+",
default=[512, 256, 128])
parser.add_argument("--interaction_op", type=str, default="dot",
choices=["dot", "cat"])
parser.add_argument("--self_interaction", default=False,
action="store_true")
parser.add_argument("--hash_indices", default=False,
action="store_true")
parser.add_argument("--cpu", default=False, action="store_true")
parser.add_argument("--dataset", type=str, required=True)
return parser.parse_args(model_args)
def initialize_model(args, categorical_sizes):
''' return model, ready to trace '''
base_device = "cuda:0" if not args.cpu else "cpu"
model_config = {
"top_mlp_sizes": args.top_mlp_sizes,
"bottom_mlp_sizes": args.bottom_mlp_sizes,
"embedding_dim": args.embedding_dim,
"interaction_op": args.interaction_op,
"self_interaction": args.self_interaction,
"categorical_feature_sizes": categorical_sizes,
"num_numerical_features": args.num_numerical_features,
"hash_indices": args.hash_indices,
"base_device": base_device
}
model = Dlrm.from_dict(model_config, sigmoid=True)
model.to(base_device)
if args.model_checkpoint:
model.load_state_dict(torch.load(args.model_checkpoint,
map_location="cpu"))
if args.fp16:
model = model.half()
return model
def get_dataloader(args, categorical_sizes):
dataset_test = SyntheticDataset(num_entries=2000,
batch_size=args.batch_size,
dense_features=args.num_numerical_features,
categorical_feature_sizes=categorical_sizes,
device="cpu" if args.cpu else "cuda:0")
class RemoveOutput:
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
value = self.dataset[idx]
if args.fp16:
value = (value[0].half(), value[1].long(), value[2])
else:
value = (value[0], value[1].long(), value[2])
return value[:-1]
def __len__(self):
return len(self.dataset)
test_loader = torch.utils.data.DataLoader(RemoveOutput(dataset_test),
batch_size=None,
num_workers=0,
pin_memory=False)
return test_loader
if __name__=='__main__':
deployer, model_args = deployer_lib.create_deployer(sys.argv[1:],
get_model_args)  # returns the deployer object and the parsed model-specific arguments
with open(os.path.join(model_args.dataset, "model_size.json")) as f:
categorical_sizes = list(json.load(f).values())
model = initialize_model(model_args, categorical_sizes)
dataloader = get_dataloader(model_args, categorical_sizes)
if model_args.dump_perf_data:
input_0, input_1 = next(iter(dataloader))
if model_args.fp16:
input_0 = input_0.half()
os.makedirs(model_args.dump_perf_data, exist_ok=True)
input_0.detach().cpu().numpy()[0].tofile(os.path.join(model_args.dump_perf_data, "input__0"))
input_1.detach().cpu().numpy()[0].tofile(os.path.join(model_args.dump_perf_data, "input__1"))
deployer.deploy(dataloader, model)
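Once deployed, the exported TorchScript artifact can be sanity-checked directly, without Triton. A minimal sketch, assuming a GPU trace exported with --fp16 under the default --save-dir, --triton-model-name and --triton-model-version, the default 13 dense features and an illustrative 26 categorical features (indices must stay below the corresponding embedding sizes):
import torch
ts_model = torch.jit.load("./triton_models/model/1/model.pt")  # default save-dir/model-name/version layout
num = torch.rand(16, 13, dtype=torch.float16, device="cuda:0")
cat = torch.randint(0, 10, (16, 26), dtype=torch.long, device="cuda:0")
with torch.no_grad():
    print(ts_model(num, cat)[:5])  # click probabilities (the model is exported with sigmoid=True)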

View file

@@ -0,0 +1,540 @@
#!/usr/bin/python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import shutil
import time
import json
import onnx
import torch
import argparse
import statistics
import onnxruntime
from collections import Counter
torch_type_to_triton_type = {
torch.bool: 'TYPE_BOOL',
torch.int8: 'TYPE_INT8',
torch.int16: 'TYPE_INT16',
torch.int32: 'TYPE_INT32',
torch.int64: 'TYPE_INT64',
torch.uint8: 'TYPE_UINT8',
torch.float16: 'TYPE_FP16',
torch.float32: 'TYPE_FP32',
torch.float64: 'TYPE_FP64'
}
CONFIG_TEMPLATE = r"""
name: "{model_name}"
platform: "{platform}"
max_batch_size: {max_batch_size}
input [
{spec_inputs}
]
output [
{spec_outputs}
]
{dynamic_batching}
{model_optimizations}
instance_group [
{{
count: {engine_count}
kind: KIND_GPU
gpus: [ {gpu_list} ]
}}
]
"""
INPUT_TEMPLATE = r"""
{{
name: "input__{num}"
data_type: {type}
dims: {dims}
{reshape}
}},"""
OUTPUT_TEMPLATE = r"""
{{
name: "output__{num}"
data_type: {type}
dims: {dims}
{reshape}
}},"""
MODEL_OPTIMIZATION_TEMPLATE = r"""
optimization {{
execution_accelerators {{
gpu_execution_accelerator: [
{{
name: "tensorrt"
}}
]
}}
}}
"""
def remove_empty_lines(text):
''' removes empty lines from text, returns the result '''
ret = "".join([s for s in text.strip().splitlines(True) if s.strip()])
return ret
def create_deployer(argv, model_args_parser):
''' takes a list of arguments, returns a deployer object and the parsed model-specific arguments '''
parser = argparse.ArgumentParser()
# required args
method = parser.add_mutually_exclusive_group(required=True)
method.add_argument('--ts-script',
action='store_true',
help='convert to torchscript using torch.jit.script')
method.add_argument('--ts-trace',
action='store_true',
help='convert to torchscript using torch.jit.trace')
method.add_argument('--onnx',
action='store_true',
help='convert to onnx using torch.onnx.export')
# triton related args
arguments = parser.add_argument_group('triton related flags')
arguments.add_argument('--triton-no-cuda',
action='store_true',
help='Use the CPU for tracing.')
arguments.add_argument(
'--triton-model-name',
type=str,
default="model",
help="exports to appropriate directory structure for triton")
arguments.add_argument(
"--triton-model-version",
type=int,
default=1,
help="exports to appropriate directory structure for triton")
arguments.add_argument(
"--triton-max-batch-size",
type=int,
default=8,
help="Specifies the 'max_batch_size' in the triton model config.\
See the triton documentation for more info.")
arguments.add_argument(
"--triton-dyn-batching-delay",
type=float,
default=0,
help=
"Determines the dynamic_batching queue delay in milliseconds(ms) for\
the triton model config. Use '0' or '-1' to specify static batching.\
See the triton documentation for more info.")
arguments.add_argument(
"--triton-engine-count",
type=int,
default=1,
help=
"Specifies the 'instance_group' count value in the triton model config.\
See the triton documentation for more info.")
arguments.add_argument('--save-dir',
type=str,
default='./triton_models',
help='Saved model directory')
# other args
arguments = parser.add_argument_group('other flags')
# remainder args
arguments.add_argument(
'model_arguments',
nargs=argparse.REMAINDER,
help=
'arguments that will be ignored by deployer lib and will be forwarded to your deployer script'
)
#
args = parser.parse_args(argv)
model_args = model_args_parser(args.model_arguments[1:])
model_args_no_def = {
k: v
for k, v in vars(model_args).items()
if k in [arg[2:] for arg in args.model_arguments[1:]]
}
deployer = Deployer(args, model_args_no_def)
#
return deployer, model_args
class DeployerLibrary:
def __init__(self, args, model_args):
self.args = args
self.model_args = model_args
self.platform = None
def set_platform(self, platform):
''' sets the platform
:: platform :: "pytorch_libtorch" or "onnxruntime_onnx"
'''
self.platform = platform
def prepare_inputs(self, dataloader, device):
''' load sample inputs to device '''
inputs = []
for batch in dataloader:
if type(batch) is torch.Tensor:
batch_d = batch.to(device)
batch_d = (batch_d, )
inputs.append(batch_d)
else:
batch_d = []
for x in batch:
assert type(x) is torch.Tensor, "input is not a tensor"
batch_d.append(x.to(device) if device else x)
batch_d = tuple(batch_d)
inputs.append(batch_d)
return inputs
def get_list_of_shapes(self, l, fun):
''' returns the list of min/max shapes, depending on fun
:: l :: list of tuples of tensors
:: fun :: min or max
'''
tensor_tuple = l[0]
shapes = [list(x.shape) for x in tensor_tuple]
for tensor_tuple in l:
assert len(tensor_tuple) == len(
shapes), "tensors with varying shape lengths are not supported"
for i, x in enumerate(tensor_tuple):
for j in range(len(x.shape)):
shapes[i][j] = fun(shapes[i][j], x.shape[j])
return shapes # a list of shapes
def get_tuple_of_min_shapes(self, l):
''' returns the tuple of min shapes
:: l :: list of tuples of tensors '''
shapes = self.get_list_of_shapes(l, min)
min_batch = 1
shapes = [[min_batch, *shape[1:]] for shape in shapes]
shapes = tuple(shapes)
return shapes # tuple of min shapes
def get_tuple_of_max_shapes(self, l):
''' returns the tuple of max shapes
:: l :: list of tuples of tensors '''
shapes = self.get_list_of_shapes(l, max)
max_batch = max(2, shapes[0][0])
shapes = [[max_batch, *shape[1:]] for shape in shapes]
shapes = tuple(shapes)
return shapes # tuple of max shapes
def get_tuple_of_opt_shapes(self, l):
''' returns the tuple of opt shapes
:: l :: list of tuples of tensors '''
counter = Counter()
for tensor_tuple in l:
shapes = [x.shape for x in tensor_tuple]
shapes = tuple(shapes)
counter[shapes] += 1
shapes = counter.most_common(1)[0][0]
return shapes # tuple of the most commonly occurring shapes
def get_tuple_of_dynamic_shapes(self, l):
''' returns a tuple of dynamic shapes: variable tensor dimensions
(for ex. batch size) occur as -1 in the tuple
:: l :: list of tuples of tensors '''
tensor_tuple = l[0]
shapes = [list(x.shape) for x in tensor_tuple]
for tensor_tuple in l:
err_msg = "tensors with varying shape lengths are not supported"
assert len(tensor_tuple) == len(shapes), err_msg
for i, x in enumerate(tensor_tuple):
for j in range(len(x.shape)):
if shapes[i][j] != x.shape[j] or j == 0:  # the batch dimension (j == 0) is always dynamic
shapes[i][j] = -1
shapes = tuple(shapes)
return shapes # tuple of dynamic shapes
def run_models(self, models, inputs):
''' run the models on inputs, return the outputs and execution times '''
ret = []
for model in models:
torch.cuda.synchronize()
time_start = time.time()
outputs = []
for input in inputs:
with torch.no_grad():
output = model(*input)
if type(output) is torch.Tensor:
output = [output]
outputs.append(output)
torch.cuda.synchronize()
time_end = time.time()
t = time_end - time_start
ret.append(outputs)
ret.append(t)
return ret
def compute_errors(self, outputs_A, outputs_B):
''' returns the list of L_inf errors computed over every single output tensor '''
Linf_errors = []
for output_A, output_B in zip(outputs_A, outputs_B):
for x, y in zip(output_A, output_B):
error = (x - y).norm(float('inf')).item()
Linf_errors.append(error)
return Linf_errors
def print_errors(self, Linf_errors):
''' print various statistics of L_inf errors '''
print()
print("conversion correctness test results")
print("-----------------------------------")
print("maximal absolute error over dataset (L_inf): ",
max(Linf_errors))
print()
print("average L_inf error over output tensors: ",
statistics.mean(Linf_errors))
print("variance of L_inf error over output tensors: ",
statistics.variance(Linf_errors))
print("stddev of L_inf error over output tensors: ",
statistics.stdev(Linf_errors))
print()
def write_config(self,
config_filename,
input_shapes,
input_types,
output_shapes,
output_types):
''' writes triton config file
:: config_filename :: the file to write the config file into
:: input_shapes :: tuple of dynamic shapes of the input tensors
:: input_types :: tuple of torch types of the input tensors
:: output_shapes :: tuple of dynamic shapes of the output tensors
:: output_types :: tuple of torch types of the output tensors
'''
assert self.platform is not None, "error - platform is not set"
config_template = CONFIG_TEMPLATE
accelerator_template = MODEL_OPTIMIZATION_TEMPLATE
input_template = INPUT_TEMPLATE
spec_inputs = r""""""
for i,(shape,typ) in enumerate(zip(input_shapes,input_types)):
d = {
'num' : str(i),
'type': torch_type_to_triton_type[typ],
'dims': str([1]) if len(shape) == 1 else str(list(shape)[1:]) # first dimension is the batch size
}
d['reshape'] = 'reshape: { shape: [ ] }' if len(shape) == 1 else ''
spec_inputs += input_template.format_map(d)
spec_inputs = spec_inputs[:-1]
output_template = OUTPUT_TEMPLATE
spec_outputs = r""""""
for i,(shape,typ) in enumerate(zip(output_shapes,output_types)):
d = {
'num' : str(i),
'type': torch_type_to_triton_type[typ],
'dims': str([1]) if len(shape) == 1 else str(list(shape)[1:]) # first dimension is the batch size
}
d['reshape'] = 'reshape: { shape: [ ] }' if len(shape) == 1 else ''
spec_outputs += output_template.format_map(d)
spec_outputs = spec_outputs[:-1]
batching_str = ""
parameters_str = ""
max_batch_size = self.args.triton_max_batch_size
accelerator_str = ""
if (self.args.triton_dyn_batching_delay > 0):
# Use only full and half full batches
pref_batch_size = [int(max_batch_size / 2.0), max_batch_size]
batching_str = r"""
dynamic_batching {{
preferred_batch_size: [{0}]
max_queue_delay_microseconds: {1}
}}""".format(", ".join([str(x) for x in pref_batch_size]),
int(self.args.triton_dyn_batching_delay * 1000.0))
if self.platform == 'onnxruntime_onnx':
accelerator_str = accelerator_template.format_map({})
config_values = {
"model_name":
self.args.triton_model_name,
"platform":
self.platform,
"max_batch_size":
max_batch_size,
"spec_inputs":
spec_inputs,
"spec_outputs":
spec_outputs,
"dynamic_batching":
batching_str,
"model_parameters":
parameters_str,
"model_optimizations":
accelerator_str,
"gpu_list":
", ".join([str(x) for x in range(torch.cuda.device_count())]),
"engine_count":
self.args.triton_engine_count
}
# write config
with open(config_filename, "w") as file:
final_config_str = config_template.format_map(config_values)
final_config_str = remove_empty_lines(final_config_str)
file.write(final_config_str)
class Deployer:
def __init__(self, args, model_args):
self.args = args
self.lib = DeployerLibrary(args, model_args)
def deploy(self, dataloader, model):
''' deploy the model and test for correctness with dataloader '''
if self.args.ts_script or self.args.ts_trace:
self.lib.set_platform("pytorch_libtorch")
print("deploying model " + self.args.triton_model_name +
" in format " + self.lib.platform)
self.to_triton_torchscript(dataloader, model)
elif self.args.onnx:
self.lib.set_platform("onnxruntime_onnx")
print("deploying model " + self.args.triton_model_name +
" in format " + self.lib.platform)
self.to_triton_onnx(dataloader, model)
else:
assert False, "error: expected one of --ts-script, --ts-trace or --onnx"
print("done")
def to_triton_onnx(self, dataloader, model):
''' export the model to onnx and test correctness on dataloader '''
model.eval()
assert not model.training, "internal error - model should be in eval() mode! "
# prepare inputs
inputs = self.lib.prepare_inputs(dataloader, device=None)
# generate outputs
outputs = []
for input in inputs:
with torch.no_grad():
output = model(*input)
if type(output) is torch.Tensor:
output = [output]
outputs.append(output)
# generate input shapes - dynamic tensor shape support
input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
# generate output shapes - dynamic tensor shape support
output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
# generate input types
input_types = [x.dtype for x in inputs[0]]
# generate output types
output_types = [x.dtype for x in outputs[0]]
# get input names
rng = range(len(input_types))
input_names = ["input__" + str(num) for num in rng]
# get output names
rng = range(len(output_types))
output_names = ["output__" + str(num) for num in rng]
# prepare save path
model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
if not os.path.exists(version_folder):
os.makedirs(version_folder)
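# with use_external_data_format=True the export writes the graph plus separate external
# weight files, so 'model.onnx' is first created as a directory that will hold all of them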
final_model_path = os.path.join(version_folder, 'model.onnx')
if not os.path.exists(final_model_path):
os.makedirs(final_model_path)
final_model_path = os.path.join(final_model_path, 'model.onnx')
# get indices of dynamic input and output shapes
dynamic_axes = {}
for input_name,input_shape in zip(input_names,input_shapes):
dynamic_axes[input_name] = [i for i,x in enumerate(input_shape) if x == -1]
for output_name,output_shape in zip(output_names,output_shapes):
dynamic_axes[output_name] = [i for i,x in enumerate(output_shape) if x == -1]
# export the model
assert not model.training, "internal error - model should be in eval() mode! "
with torch.no_grad():
torch.onnx.export(model, inputs[0], final_model_path, verbose=False,
input_names=input_names, output_names=output_names,
dynamic_axes=dynamic_axes, opset_version=11,
use_external_data_format=True)
config_filename = os.path.join(model_folder, "config.pbtxt")
self.lib.write_config(config_filename,
input_shapes, input_types,
output_shapes, output_types)
def to_triton_torchscript(self, dataloader, model):
''' export the model to torchscript and test correctness on dataloader '''
model.eval()
assert not model.training, "internal error - model should be in eval() mode! "
# prepare inputs
inputs = self.lib.prepare_inputs(dataloader, device=None)
# generate input shapes - dynamic tensor shape support
input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
# generate input types
input_types = [x.dtype for x in inputs[0]]
# prepare save path
model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
if not os.path.exists(version_folder):
os.makedirs(version_folder)
final_model_path = os.path.join(version_folder, 'model.pt')
# convert the model
with torch.no_grad():
if self.args.ts_trace: # trace it
model_ts = torch.jit.trace(model, inputs[0])
if self.args.ts_script: # script it
model_ts = torch.jit.script(model)
# generate outputs
outputs = []
for input in inputs:
with torch.no_grad():
output = model(*input)
if type(output) is torch.Tensor:
output = [output]
outputs.append(output)
# save the model
torch.jit.save(model_ts, final_model_path)
# generate output shapes - dynamic tensor shape support
output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
# generate output types
output_types = [x.dtype for x in outputs[0]]
# now we build the config for triton
config_filename = os.path.join(model_folder, "config.pbtxt")
self.lib.write_config(config_filename,
input_shapes, input_types,
output_shapes, output_types)
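The config templates above are stitched together by DeployerLibrary.write_config. A minimal sketch of rendering a config.pbtxt in isolation, assuming the TorchScript platform, FP16 dense and INT64 categorical inputs, and illustrative feature counts (13 dense, 26 categorical):
import argparse
import torch
args = argparse.Namespace(triton_model_name="dlrm",
                          triton_max_batch_size=8,
                          triton_dyn_batching_delay=0,
                          triton_engine_count=1)
lib = DeployerLibrary(args, model_args={})
lib.set_platform("pytorch_libtorch")
lib.write_config("config.pbtxt",
                 input_shapes=([-1, 13], [-1, 26]),   # dynamic batch dimension reported as -1
                 input_types=[torch.float16, torch.int64],
                 output_shapes=([-1, 1],),
                 output_types=[torch.float32])
print(open("config.pbtxt").read())  # name, platform, max_batch_size, input/output specs, instance_group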

Binary file not shown.


Binary file not shown.


View file

@@ -25,13 +25,14 @@ The examples are organized first by framework, such as TensorFlow, PyTorch, etc.
- __VNet__ [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Segmentation/VNet)]
### Natural Language Processing
- __BERT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT)]
- __GNMT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/GNMT)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Translation/GNMT)]
- __Transformer__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/Transformer)]
- __BERT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT)]
- __Transformer-XL__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/Transformer-XL)]
### Recommender Systems
- __DLRM__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/DLRM)]
- __NCF__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/NCF)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/NCF)]
- __VAE-CF__ [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/VAE-CF)]
- __WideAndDeep__ [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/WideAndDeep)]
@@ -70,6 +71,7 @@ The examples are organized first by framework, such as TensorFlow, PyTorch, etc.
| [BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT) |PyTorch | N/A | Yes | Yes | Yes | - | - | [Yes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT/triton) | - |
| [Transformer-XL](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL) |PyTorch | N/A | Yes | Yes | Yes | - | - | - | - |
| [Neural Collaborative Filtering](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/NCF) |PyTorch | N/A | Yes | Yes | - | - |- | - | - |
| [DLRM](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/DLRM) |PyTorch | N/A | Yes | Yes | - | - |- | - | - |
| [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/MaskRCNN) |PyTorch | N/A | Yes | Yes | - | - | - | - | - |
| [Jasper](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechRecognition/Jasper) |PyTorch | N/A | Yes | Yes | - | Yes | Yes | [Yes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechRecognition/Jasper/trtis) | - |
| [Tacotron 2 And WaveGlow v1.10](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2) | PyTorch | N/A | Yes | Yes | - | Yes | Yes | [Yes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2/notebooks/trtis) | - |