Merge pull request #636 from NVIDIA/gh/release

[VAE/TF] Updating for Ampere
This commit is contained in:
nv-kkudrynski 2020-08-05 20:55:02 +02:00 committed by GitHub
commit 280e75c63e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 305 additions and 276 deletions

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:19.11-tf1-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:20.06-tf1-py3
FROM ${FROM_IMAGE_NAME}
ADD requirements.txt .

View file

@ -4,48 +4,53 @@ This repository provides a script and recipe to train the Variational Autoencode
## Table Of Contents
- [Model overview](#model-overview)
* [Model architecture](#model-architecture)
* [Default configuration](#default-configuration)
* [Feature support matrix](#feature-support-matrix)
* [Features](#features)
* [Mixed precision training](#mixed-precision-training)
* [Enabling mixed precision](#enabling-mixed-precision)
- [Setup](#setup)
* [Requirements](#requirements)
- [Quick Start Guide](#quick-start-guide)
- [Advanced](#advanced)
* [Scripts and sample code](#scripts-and-sample-code)
* [Parameters](#parameters)
* [Command-line options](#command-line-options)
* [Getting the data](#getting-the-data)
* [Model overview](#model-overview)
* [Model architecture](#model-architecture)
* [Default configuration](#default-configuration)
* [Feature support matrix](#feature-support-matrix)
* [Features](#features)
* [Mixed precision training](#mixed-precision-training)
* [Enabling mixed precision](#enabling-mixed-precision)
* [Enabling TF32](#enabling-tf32)
* [Setup](#setup)
* [Requirements](#requirements)
* [Quick Start Guide](#quick-start-guide)
* [Advanced](#advanced)
* [Scripts and sample code](#scripts-and-sample-code)
* [Parameters](#parameters)
* [Command-line options](#command-line-options)
* [Getting the data](#getting-the-data)
* [Dataset guidelines](#dataset-guidelines)
* [Training process](#training-process)
* [Inference process](#inference-process)
- [Performance](#performance)
* [Benchmarking](#benchmarking)
* [Training process](#training-process)
* [Inference process](#inference-process)
* [Performance](#performance)
* [Benchmarking](#benchmarking)
* [Training performance benchmark](#training-performance-benchmark)
* [Inference performance benchmark](#inference-performance-benchmark)
* [Results](#results)
* [Results](#results)
* [Training accuracy results](#training-accuracy-results)
* [Training accuracy: NVIDIA DGX-1 (8x V100 16G)](#training-accuracy-nvidia-dgx-1-8x-v100-16g)
* [Training accuracy: NVIDIA DGX A100 (8x A100 40GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-40gb)
* [Training accuracy: NVIDIA DGX-1 (8x V100 32GB)](#training-accuracy-nvidia-dgx-1-8x-v100-32gb)
* [Training performance results](#training-performance-results)
* [Training performance: NVIDIA DGX-1 (8x V100 16G)](#training-performance-nvidia-dgx-1-8x-v100-16g)
* [Training performance: NVIDIA DGX A100 (8x A100 40GB)](#training-performance-nvidia-dgx-a100-8x-a100-40gb)
* [Training performance: NVIDIA DGX-1 (8x V100 32GB)](#training-performance-nvidia-dgx-1-8x-v100-32gb)
* [Inference performance results](#inference-performance-results)
* [Inference performance: NVIDIA DGX-1 (1x V100 16G)](#inference-performance-nvidia-dgx-1-1x-v100-16g)
- [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
* [Inference performance: NVIDIA DGX A100 (1x A100 40GB)](#inference-performance-nvidia-dgx-a100-1x-a100-40gb)
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
* [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
* [AMP speedup for Ampere](#amp-speedup-for-ampere)
* [Multi-GPU scaling](#multi-gpu-scaling)
## Model overview
The Variational Autoencoder (VAE) shown here is an optimized implementation of the architecture first described in [Variational Autoencoders for Collaborative Filtering](https://arxiv.org/abs/1802.05814) and can be used for recommendation tasks. The main differences between this model and the original one are the performance optimizations, such as using sparse matrices, mixed precision, larger mini-batches and multiple GPUs. These changes enabled us to achieve a significantly better speed while maintaining the same accuracy. Because of our fast implementation, weve also been able to carry out an extensive hyperparameter search to slightly improve the accuracy metrics.
The Variational Autoencoder (VAE) shown here is an optimized implementation of the architecture first described in [Variational Autoencoders for Collaborative Filtering](https://arxiv.org/abs/1802.05814) and can be used for recommendation tasks. The main differences between this model and the original one are the performance optimizations, such as using sparse matrices, mixed precision, larger mini-batches and multiple GPUs. These changes enabled us to achieve a significantly higher speed while maintaining the same accuracy. Because of our fast implementation, we've also been able to carry out an extensive hyperparameter search to slightly improve the accuracy metrics.
When using Variational Autoencoder for Collaborative Filtering (VAE-CF), you can quickly train a recommendation model for a collaborative filtering task. The required input data consists of pairs of user-item IDs for each interaction between a user and an item. With a trained model, you can run inference to predict what items are a new user most likely to interact with.
When using Variational Autoencoder for Collaborative Filtering (VAE-CF), you can quickly train a recommendation model for the collaborative filtering task. The required input data consists of pairs of user-item IDs for each interaction between a user and an item. With a trained model, you can run inference to predict what items is a new user most likely to interact with.
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta and Turing GPUs. Therefore, researchers can get results 1.9x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, Turing and Ampere GPUs. Therefore, researchers can get results 1.9x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
This implementation has been initially developed as an educational project at the University of Warsaw by Albert Cieślak, Michał Filipiuk, Frederic Grabowski and Radosław Rowicki.
@ -57,49 +62,47 @@ This implementation has been initially developed as an educational project at th
Figure 1. The architecture of the VAE-CF model </p>
The Variational Autoencoder is a neural network that provides collaborative filtering based on implicit feedback. Specifically, it provides product recommendations based on user and item interactions. The training data for this model should contain a sequence of user ID, item ID pairs indicating that the specified user has interacted with, and the specified item.
The Variational Autoencoder is a neural network that provides collaborative filtering based on implicit feedback. Specifically, it provides product recommendations based on user and item interactions. The training data for this model should contain a sequence of (user ID, item ID) pairs indicating that the specified user has interacted with the specified item.
The model consists of two parts: the encoder and the decoder.
The encoder transforms the vector, that contains the interactions for a specific user, into an n-dimensional variational distribution. We can then use this variational distribution to obtain a latent representation of a user.
The model consists of two parts: the encoder and the decoder.
The encoder transforms the vector, which contains the interactions for a specific user, into a *n*-dimensional variational distribution. We can then use this variational distribution to obtain a latent representation of a user.
This latent representation is then fed into the decoder. The result is a vector of item interaction probabilities for a particular user.
### Default configuration
The following features were implemented in this model:
- general
- sparse matrix support
- data-parallel multi-GPU training
- dynamic loss scaling with backoff for tensor cores (mixed precision) training
- Sparse matrix support
- Data-parallel multi-GPU training
- Dynamic loss scaling with backoff for tensor cores (mixed precision) training
### Feature support matrix
The following features are supported by this model:
The following features are supported by this model:
| Feature | VAE-CF
|-----------------------|--------------------------
|Horovod Multi-GPU (NCCL) | Yes
|Automatic mixed precision (AMP) | Yes
| Feature | VAE-CF
|-----------------------|--------------------------
|Horovod Multi-GPU (NCCL) | Yes
|Automatic mixed precision (AMP) | Yes
#### Features
##### Horovod
Horovod:
Horovod is a distributed training framework for TensorFlow, Keras, PyTorch and MXNet. The goal of Horovod is to make distributed deep learning fast and easy to use. For more information about how to get started with Horovod, see the [Horovod: Official repository](https://github.com/horovod/horovod).
##### Multi-GPU training with Horovod
Multi-GPU training with Horovod:
Our model uses Horovod to implement efficient multi-GPU training with NCCL. For details, see example sources in this repository or see the [TensorFlow tutorial](https://github.com/horovod/horovod/#usage).
### Mixed precision training
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format, while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in the Volta and Turing architecture, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
1. Porting the model to use the FP16 data type where appropriate.
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format, while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in Volta, and following with both the Turing and Ampere architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
1. Porting the model to use the FP16 data type where appropriate.
2. Adding loss scaling to preserve small gradient values.
The ability to train deep learning networks with lower precision was introduced in the Pascal architecture and first supported in [CUDA 8](https://devblogs.nvidia.com/parallelforall/tag/fp16/) in the NVIDIA Deep Learning SDK.
This can now be achieved using Automatic Mixed Precision (AMP) for TensorFlow to enable the full [mixed precision methodology](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#tensorflow) in your existing TensorFlow model code. AMP enables mixed precision training on Volta, Turing, and NVIDIA Ampere GPU architectures automatically. The TensorFlow framework code makes all necessary model changes internally.
In TF-AMP, the computational graph is optimized to use as few casts as necessary and maximize the use of FP16, and the loss scaling is automatically applied inside of supported optimizers. AMP can be configured to work with the existing tf.contrib loss scaling manager by disabling the AMP scaling with a single environment variable to perform only the automatic mixed-precision optimization. It accomplishes this by automatically rewriting all computation graphs with the necessary operations to enable mixed precision training and automatic loss scaling.
For information about:
- How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.
@ -109,7 +112,32 @@ For information about:
#### Enabling mixed precision
To enable mixed precision in VAE-CF, run the `main.py` script with the `--use_tf_amp` flag.
Mixed precision is enabled in TensorFlow by using the Automatic Mixed Precision (TF-AMP) extension which casts variables to half-precision upon retrieval, while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In TensorFlow, loss scaling can be applied statically by using simple multiplication of loss by a constant value or automatically, by TF-AMP. Automatic mixed precision makes all the adjustments internally in TensorFlow, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running TensorFlow models.
To enable mixed precision, you can simply add the values to the environmental variables inside your training script:
- Enable TF-AMP graph rewrite:
```
os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = '1'
```
- Enable Automated Mixed Precision:
```
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
```
To enable mixed precision in VAE-CF, run the `main.py` script with the `--amp` flag.
#### Enabling TF32
TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs.
TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require high dynamic range for weights or activations.
For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
## Setup
@ -120,51 +148,78 @@ The following section lists the requirements that you need to meet in order to s
This repository contains Dockerfile which extends the Tensorflow NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
- Tensorflow 19.11+ NGC container
- [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
- TensorFlow-1 20.06+ NGC container
- Supported GPUs:
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/geforce/turing/)
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
- [Running TensorFlow](https://docs.nvidia.com/deeplearning/frameworks/tensorflow-release-notes/running.html#running)
- [Running TensorFlow](https://docs.nvidia.com/deeplearning/frameworks/tensorflow-release-notes/running.html#running)
For those unable to use the TensorFlow NGC container, to set up the required environment or create your own container, see the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
## Quick Start Guide
To train your model using mixed precision with Tensor Cores or using FP32, perform the following steps using the default parameters of the VAE-CF model on the [MovieLens 20m dataset](https://grouplens.org/datasets/movielens/20m/). For the specifics concerning training and inference, see the [Advanced](#advanced) section.
To train your model using mixed or TF32 precision with Tensor Cores or using FP32, perform the following steps using the default parameters of the VAE-CF model on the [MovieLens 20m dataset](https://grouplens.org/datasets/movielens/20m/). For the specifics concerning training and inference, see the [Advanced](#advanced) section.
1. Clone the repository.
```bash
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/Tensorflow/Recommendation/VAE_CF
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/Tensorflow/Recommendation/VAE_CF
```
2. Build the VAE TensorFlow NGC container.
```bash
docker build . -t vae
```
```bash
docker build . -t vae
```
3. Launch the VAE-CF TensorFlow Docker container.
```bash
docker run -it --rm --runtime=nvidia -v /data/vae-cf:/data vae /bin/bash
```
4. Prepare the dataset.
```bash
python3 prepare_dataset.py
```
```bash
docker run -it --rm --runtime=nvidia -v /data/vae-cf:/data vae /bin/bash
```
5. Start training.
```bash
python3 main.py --train --use_tf_amp --checkpoint_dir ./checkpoints
```
4. Downloading the dataset: Here we use the [MovieLens 20m dataset](https://grouplens.org/datasets/movielens/20m/).
6. Start validation/evaluation.
The model is exported to the default `model_dir` and can be loaded and tested using:
```bash
python3 main.py --test --use_tf_amp --checkpoint_dir ./checkpoints
```
* If you do not have the dataset downloaded: Run the commands below to download and extract the MovieLens dataset to the ```/data/ml-20m/extracted/``` folder.
```
cd /data
mkdir ml-20m
cd ml-20m
mkdir extracted
cd extracted
wget http://files.grouplens.org/datasets/movielens/ml-20m.zip
unzip ml-20m.zip
```
* If you already have the dataset downloaded and unzipped elsewhere: Run the below commands to first exit the current VAE-CF Docker container and then Restart the VAE-CF Docker Container (like in Step 3 above) by mounting the MovieLens dataset location
```
exit
docker run -it --rm --runtime=nvidia -v /data/vae-cf:/data -v <ml-20m folder path>:/data/ml-20m/extracted/ml-20m vae /bin/bash
```
where, the unzipped MovieLens dataset is at ```<ml-20m folder path>```
5. Prepare the dataset.
```bash
python prepare_dataset.py
```
6. Start training on 8 GPUs.
```bash
mpirun --bind-to numa --allow-run-as-root -np 8 -H localhost:8 python main.py --train --amp --checkpoint_dir ./checkpoints
```
7. Start validation/evaluation.
The model is exported to the default `model_dir` and can be loaded and tested using:
```bash
python main.py --test --amp --checkpoint_dir ./checkpoints
```
## Advanced
@ -173,17 +228,28 @@ The following sections provide greater details of the dataset, running training
### Scripts and sample code
The `main.py` script provides an entry point to all the provided functionalities. This includes running training, testing and inference. The behavior of the script is controlled by command-line arguments listed below in the [Parameters](#parameters) section. The `prepare_dataset.py` script can be used to download and preprocess the MovieLens 20m dataset.
The `main.py` script provides an entry point to all the provided functionalities. This includes running training, testing and inference. The behavior of the script is controlled by command-line arguments listed below in the [Parameters](#parameters) section. The `prepare_dataset.py` script can be used to preprocess the MovieLens 20m dataset.
Most of the deep learning logic is implemented in the `vae/models` subdirectory. The `vae/load` subdirectory contains code for downloading and preprocessing the dataset. The `vae/metrics` subdirectory provides functions for computing the validation metrics such as recall and [NDCG](https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG).
Most of the deep learning logic is implemented in the `vae/models` subdirectory. The `vae/load` subdirectory contains the code for preprocessing the dataset. The `vae/metrics` subdirectory provides functions for computing the validation metrics such as recall and [NDCG](https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG).
### Parameters
To train a VAE-CF model in TensorFlow the following parameters are supported:
The most important command-line parameters include:
* `--data_dir` which specifies the directory inside the docker container where the data will be stored, overriding the default location ```/data```
* `--checkpoint_dir` which controls if and where the checkpoints will be stored
* `--amp` for enabling mixed precision training
```
usage: main.py [-h] [--train] [--test] [--inference] [--inference_benchmark]
[--use_tf_amp] [--epochs EPOCHS]
There are also multiple parameters controlling the various hyperparameters of the training process, such as the learning rate, batch size etc.
### Command-line options
To see the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example:
```bash
python main.py --help
usage: main.py [-h] [--train] [--test] [--inference_benchmark]
[--amp] [--epochs EPOCHS]
[--batch_size_train BATCH_SIZE_TRAIN]
[--batch_size_validation BATCH_SIZE_VALIDATION]
[--validation_step VALIDATION_STEP]
@ -200,12 +266,9 @@ optional arguments:
-h, --help show this help message and exit
--train Run training of VAE
--test Run validation of VAE
--inference Run inference on a single random example.This can also
be used to measure the latency for a batch size of 1
--inference_benchmark
Benchmark the inference throughput on a very large
batch size
--use_tf_amp Enable Automatic Mixed Precision
Benchmark the inference throughput and latency
--amp Enable Automatic Mixed Precision
--epochs EPOCHS Number of epochs to train
--batch_size_train BATCH_SIZE_TRAIN
Global batch size for training
@ -237,47 +300,39 @@ optional arguments:
```
### Command-line options
To see the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example:
```bash
python main.py --help
```
### Getting the data
The VA-CF model was trained on the [MovieLens 20M dataset](https://grouplens.org/datasets/movielens/20m/). The dataset can be downloaded and preprocessed simply by running: `python prepare_dataset.py` in the Docker container. By default, the dataset will be stored in the `/data` directory. If you want to store the data in a different location, you can pass the desired location to the `--data_dir` argument.
The VA-CF model was trained on the [MovieLens 20M dataset](https://grouplens.org/datasets/movielens/20m/). The dataset can be preprocessed simply by running: `python prepare_dataset.py` in the Docker container. By default, the dataset will be stored in the `/data` directory. If you want to store the data in a different location, you can pass the desired location to the `--data_dir` argument.
#### Dataset guidelines
As a Collaborative Filtering model, VAE-CF only uses information about which user interacted with which item. For the MovieLens dataset, this means that a particular user has positively reviewed a particular movie. VAE-CF can be adapted to any other collaborative filtering task. The input to the model is generally a list of all interactions between users and items. One column of the CSV should contain user IDs while the other should contain item IDs. Example preprocessing for the MovieLens 20M dataset is provided in the `vae/load/preprocessing.py` file.
As a Collaborative Filtering model, VAE-CF only uses information about which user interacted with which item. For the MovieLens dataset, this means that a particular user has positively reviewed a particular movie. VAE-CF can be adapted to any other collaborative filtering task. The input to the model is generally a list of all interactions between users and items. One column of the CSV should contain user IDs, while the other should contain item IDs. Preprocessing for the MovieLens 20M dataset is provided in the `vae/load/preprocessing.py` file.
### Training process
The training can be started by running the `main.py` script with the `train` argument. The resulting checkpoints containing the trained model weights are then stored in the directory specified by the `--checkpoint_dir` directory (by default, no checkpoints are saved).
The training can be started by running the `main.py` script with the `train` argument. The resulting checkpoints containing the trained model weights are then stored in the directory specified by the `--checkpoint_dir` directory (by default no checkpoints are saved).
Additionally, a command-line argument called `--results_dir` (by default None) can be used to enable saving some statistics to JSON files in a directory specified by this parameter. The statistics saved are:
1) a complete list of command-line arguments saved as `<results_dir>/args.json` and
2) a dictionary of validation metrics and performance metrics recorded during training
Additionally, a command-line argument called `--results_dir` (by default `None`) specifies where to save the following statistics in a JSON format:
1) a complete list of command-line arguments saved as `<results_dir>/args.json`, and
2) a dictionary of validation metrics and performance metrics recorded during training.
The main validation metric used is [NDCG@100](https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG). Following the original VAE-CF paper we also report numbers for Recall@20 and Recall@50.
Multi-GPU training uses horovod. You can run it with:
```horovodrun -np 8 -H localhost:8 python3 main.py --train --use_tf_amp```
Multi-GPU training uses horovod.
Mixed precision support is controlled by the `--use_tf_amp` command-line flag. It enables TensorFlows Automatic Mixed Precision mode.
Mixed precision support is controlled by the `--amp` command-line flag. It enables TensorFlows Automatic Mixed Precision mode.
### Inference process
Inference on a trained model can be run by passing the `--inference` argument to the main.py script, for example:
```
python3 main.py --inference --use_tf_amp --checkpoint_dir /checkpoints
```
This will generate a user with a collection of random items that they interacted with and run inference for that user. The result is a list of K recommended items the user is likely to interact with. You can control the number of items to be recommended by setting the `--top_results` command-line argument (by default 100).
Inference on a trained model can be run by passing the `--inference_benchmark` argument to the main.py script
```
python main.py --inference_benchmark [--amp] --checkpoint_dir ./checkpoints
```
This will generate a user with a collection of random items that they interacted with and run inference for that user multiple times to measure latency and throughput.
## Performance
@ -290,19 +345,16 @@ The following section shows how to run benchmarks measuring the model performanc
To benchmark the training performance, run:
```
horovodrun -np 8 -H localhost:8 python3 main.py --train --use_tf_amp
mpirun --bind-to numa --allow-run-as-root -np 8 -H localhost:8 python main.py --train [--amp]
```
Training benchmark was run on 8x V100 16G GPUs.
#### Inference performance benchmark
To benchmark the inference performance, run:
```
python3 main.py --inference_benchmark --use_tf_amp --batch_size_validation 24576
python main.py --inference_benchmark [--amp]
```
Inference benchmark was run on 1x V100 16G GPU.
### Results
@ -310,61 +362,106 @@ The following sections provide details on how we achieved our performance and ac
#### Training accuracy results
##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
All training performance results were obtained by running:
```
mpirun --bind-to numa --allow-run-as-root -np <gpus> -H localhost:8 python main.py --train [--amp]
```
in the TensorFlow 20.06 NGC container.
Our results were obtained by running the `main.py` training script in the TensorFlow 19.11 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
| GPUs | Batch size / GPU | Accuracy - FP32 | Accuracy - mixed precision | Time to train - FP32 (s) | Time to train - mixed precision (s) | Time to train speedup (FP32 to mixed precision) |
|---|---|---|---|---|---|---|
| 1 | 24576 | 0.42863 | 0.42824 | 357.6| 205.9 | 1.737 |
| 8 | 3072 | 0.42763 | 0.42766 | 59.7 | 43.2 | 1.381 |
##### Training accuracy: NVIDIA DGX A100 (8x A100 40GB)
| GPUs | Batch size / GPU | Accuracy - TF32 | Accuracy - mixed precision | Time to train - TF32 [s] | Time to train - mixed precision [s] | Time to train speedup (TF32 to mixed precision)
|-------:|-----------------:|-------------:|-----------:|----------------:|--------------:|---------------:|
| 1 | 24,576 | 0.430298 | 0.430398 | 112.8 | 109.4 | 1.03 |
| 8 | 3,072 | 0.430897 | 0.430353 | 25.9 | 30.4 | 0.85 |
##### Training accuracy: NVIDIA DGX-1 (8x V100 32GB)
| GPUs | Batch size / GPU | Accuracy - FP32 | Accuracy - mixed precision | Time to train - FP32 [s] | Time to train - mixed precision [s] | Time to train speedup (FP32 to mixed precision) |
|-------:|-----------------:|-------------:|-----------:|----------------:|--------------:|---------------:|
| 1 | 24,576 | 0.430592 | 0.430525 | 346.5 | 186.5 | 1.86 |
| 8 | 3,072 | 0.430753 | 0.431202 | 59.1 | 42.2 | 1.40 |
#### Training performance results
##### Training performance: NVIDIA DGX-1 (8x V100 16G)
Performance numbers below show throughput in users processed per second. They were averaged over an entire training run.
Our results were obtained by running:
```
horovodrun -np 8 -H localhost:8 python3 main.py --train --use_tf_amp
```
in the TensorFlow 19.11 NGC container on NVIDIA DGX-1 with 8x V100 16G GPUs. Performance numbers (throughput in users processed per second) were averaged over an entire training run.
##### Training performance: NVIDIA DGX A100 (8x A100 40GB)
| GPUs | Batch size / GPU | Throughput - TF32 | Throughput - mixed precision | Throughput speedup (TF32 - mixed precision) | Strong scaling - TF32 | Strong scaling - mixed precision
|-------:|------------:|-------------------:|-----------------:|---------------------:|---:|---:|
| 1 | 24,576 | 354,032 | 365,474 | 1.03 | 1 | 1 |
| 8 | 3,072 | 1,660,700 | 1,409,770 | 0.85 | 4.69 | 3.86 |
##### Training performance: NVIDIA DGX-1 (8x V100 32GB)
| GPUs | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) | Strong scaling - FP32 | Strong scaling - mixed precision |
|---|---|---|---|---|---|---|
| 1 | 24576| 116k | 219k | 1.897 | 1.00| 1.00|
| 8 | 3072 | 685k | 966k | 1.410 | 5.92 | 4.41 |
We use users processed per second as a throughput metric for measuring training performance.
|-------:|------------:|-------------------:|-----------------:|---------------------:|---:|---:|
| 1 | 24,576 | 114,125 | 213,283 | 1.87 | 1 | 1 |
| 8 | 3,072 | 697,628 | 1,001,210 | 1.44 | 6.11 | 4.69 |
#### Inference performance results
##### Inference performance: NVIDIA DGX-1 (1x V100 16G)
Our results were obtained by running:
```
python3 main.py --inference_benchmark --use_tf_amp --batch_size_validation 24576
python main.py --inference_benchmark [--amp]
```
in the TensorFlow 19.11 NGC container on NVIDIA DGX-1 with (1x V100 16G) GPU.
| GPUs | Batch size / GPU | Inference Throughput - FP32 | Inference Throughput - mixed precision | Inference Throughput speedup (FP32 - mixed precision) |
|---|---|---|---|---|
| 1 | 24576| 127k | 154k | 1.215 |
in the TensorFlow 20.06 NGC container.
We use users processed per second as a throughput metric for measuring inference performance.
All latency numbers are in seconds.
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Inference performance: NVIDIA DGX A100 (1x A100 40GB)
TF32
| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
|-------------:|-----------------:|--------------:|--------------:|--------------:|---------------:|
| 1 | 1181 | 0.000847 | 0.000863 | 0.000871 | 0.000901 |
FP16
| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
|-------------:|-----------------:|--------------:|--------------:|--------------:|---------------:|
| 1 | 1215 | 0.000823 | 0.000858 | 0.000864 | 0.000877 |
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
FP32
| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
|-------------:|-----------------:|--------------:|--------------:|--------------:|---------------:|
| 1 | 718 | 0.001392 | 0.001443 | 0.001458 | 0.001499 |
FP16
| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
|-------------:|-----------------:|--------------:|--------------:|--------------:|---------------:|
| 1 | 707 | 0.001413 | 0.001511 | 0.001543 | 0.001622 |
## Release notes
### Changelog
July 2020
- Updated with Ampere convergence and performance results
November 2019
- Initial release
### Known issues
Multi-GPU scaling
#### AMP speedup for Ampere
In this model the TF32 precision can in some cases be as fast as the FP16 precision on Ampere GPUs.
This is because TF32 also uses Tensor Cores and doesn't need any additional logic
such as maintaining FP32 master weights and casts.
However, please note that VAE-CF is, by modern recommender standards, a very small model.
Larger models should still see significant benefits of using FP16 math.
#### Multi-GPU scaling
We benchmark this implementation on the ML-20m dataset so that our results are comparable to the original VAE-CF paper. We also use the same neural network architecture. As a consequence, the ratio of communication to computation is relatively large. This means that although using multiple GPUs speeds up the training substantially, the scaling efficiency is worse from what one would expect if using a larger model and a more realistic dataset.

View file

@ -1,6 +1,6 @@
#!/usr/bin/python3
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,14 +15,21 @@
# limitations under the License.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from functools import partial
import json
import logging
from argparse import ArgumentParser
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import numpy as np
import horovod.tensorflow as hvd
from mpi4py import MPI
import dllogger
import time
from vae.utils.round import round_8
from vae.metrics.recall import recall
@ -32,18 +39,16 @@ from vae.load.preprocessing import load_and_parse_ML_20M
def main():
hvd.init()
mpi_comm = MPI.COMM_WORLD
parser = ArgumentParser(description="Train a Variational Autoencoder for Collaborative Filtering in TensorFlow")
parser.add_argument('--train', action='store_true',
help='Run training of VAE')
parser.add_argument('--test', action='store_true',
help='Run validation of VAE')
parser.add_argument('--inference', action='store_true',
help='Run inference on a single random example.'
'This can also be used to measure the latency for a batch size of 1')
parser.add_argument('--inference_benchmark', action='store_true',
help='Benchmark the inference throughput on a very large batch size')
parser.add_argument('--use_tf_amp', action='store_true',
help='Measure inference latency and throughput on a variety of batch sizes')
parser.add_argument('--amp', action='store_true', default=False,
help='Enable Automatic Mixed Precision')
parser.add_argument('--epochs', type=int, default=400,
help='Number of epochs to train')
@ -85,6 +90,7 @@ def main():
default=None,
help='Path for saving a checkpoint after the training')
args = parser.parse_args()
args.world_size = hvd.size()
if args.batch_size_train % hvd.size() != 0:
raise ValueError('Global batch size should be a multiple of the number of workers')
@ -101,16 +107,27 @@ def main():
dllogger.init(backends=[])
logger.setLevel(logging.ERROR)
dllogger.log(data=vars(args), step='PARAMETER')
if args.seed is None:
if hvd.rank() == 0:
seed = int(time.time())
else:
seed = None
np.random.seed(args.seed)
tf.set_random_seed(args.seed)
seed = mpi_comm.bcast(seed, root=0)
else:
seed = args.seed
tf.random.set_random_seed(seed)
np.random.seed(seed)
args.seed = seed
dllogger.log(data=vars(args), step='PARAMETER')
# Suppress TF warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# set AMP
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1' if args.use_tf_amp else '0'
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1' if args.amp else '0'
# load dataset
(train_data,
@ -159,21 +176,36 @@ def main():
elif args.test and hvd.size() > 1:
print("Testing is not supported with horovod multigpu yet")
if args.inference_benchmark and hvd.size() <= 1:
# use the train data to get accurate throughput numbers for inference
# the test and validation sets are too small to measure this accurately
# vae.inference_benchmark()
_ = vae.test(test_data_input=train_data,
test_data_true=train_data, metrics={})
elif args.test and hvd.size() > 1:
print("Testing is not supported with horovod multigpu yet")
if args.inference:
input_data = np.random.randint(low=0, high=10000, size=10)
recommendations = vae.query(input_data=input_data)
print('Recommended item indices: ', recommendations)
if args.inference_benchmark:
items_per_user = 10
item_indices = np.random.randint(low=0, high=10000, size=items_per_user)
user_indices = np.zeros(len(item_indices))
indices = np.stack([user_indices, item_indices], axis=1)
num_batches = 200
latencies = []
for i in range(num_batches):
start_time = time.time()
_ = vae.query(indices=indices)
end_time = time.time()
if i < 10:
#warmup steps
continue
latencies.append(end_time - start_time)
result_data = {}
result_data[f'batch_1_mean_throughput'] = 1 / np.mean(latencies)
result_data[f'batch_1_mean_latency'] = np.mean(latencies)
result_data[f'batch_1_p90_latency'] = np.percentile(latencies, 90)
result_data[f'batch_1_p95_latency'] = np.percentile(latencies, 95)
result_data[f'batch_1_p99_latency'] = np.percentile(latencies, 99)
dllogger.log(data=result_data, step=tuple())
vae.close_session()
dllogger.flush()

View file

@ -1,96 +0,0 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import basename, normpath
import urllib.request
import tarfile
import zipfile
from tqdm import tqdm
import itertools
from glob import glob
import logging
LOG = logging.getLogger("VAE")
def download_movielens(data_dir):
destination_filepath = os.path.join(data_dir, 'ml-20m/download/ml-20m.zip')
if not glob(destination_filepath):
ml_20m_download_url = 'http://files.grouplens.org/datasets/movielens/ml-20m.zip'
download_file(ml_20m_download_url, destination_filepath)
LOG.info("Extracting")
extract_file(destination_filepath, to_directory=os.path.join(data_dir, 'ml-20m/extracted'))
def download_file(url, filename):
if not os.path.isdir(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
u = urllib.request.urlopen(url)
with open(filename, 'wb') as f:
meta = u.info()
if (meta.get_all("Content-Length")):
file_size = int(meta.get_all("Content-Length")[0])
pbar = tqdm(
total=file_size,
desc=basename(normpath(filename)),
unit='B',
unit_scale=True)
file_size_dl = 0
block_sz = 8192
while True:
buff = u.read(block_sz)
if not buff:
break
pbar.update(len(buff))
file_size_dl += len(buff)
f.write(buff)
pbar.close()
else:
LOG.warning("No content length information")
file_size_dl = 0
block_sz = 8192
for cyc in itertools.cycle('/\\|'):
buff = u.read(block_sz)
if not buff:
break
print(cyc, end='\r')
file_size_dl += len(buff)
f.write(buff)
def extract_file(path, to_directory):
"""
Extract file
:param path: Path to compressed file
:param to_directory: Directory that is going to store extracte files
"""
if (path.endswith("tar.gz")):
tar = tarfile.open(path, "r:gz")
tar.extractall(path=to_directory)
tar.close()
elif (path.endswith("tar")):
tar = tarfile.open(path, "r:")
tar.extractall(path=to_directory)
tar.close()
elif (path.endswith("zip")):
with zipfile.ZipFile(path, 'r') as zip_ref:
zip_ref.extractall(to_directory)
else:
raise Exception(
"Could not extract {} as no appropriate extractor is found".format(path))

View file

@ -23,7 +23,6 @@ import scipy.sparse as sp
import numpy as np
from scipy.sparse import load_npz, csr_matrix
from vae.load.downloaders import download_movielens
import logging
import json
@ -69,7 +68,7 @@ def save_id_mappings(cache_dir, show2id, profile2id):
json.dump(d, f, indent=4)
def load_and_parse_ML_20M(data_dir, threshold=4):
def load_and_parse_ML_20M(data_dir, threshold=4, parse=True):
"""
Original way of processing ml-20m dataset from VAE for CF paper
Copyright [2018] [Dawen Liang, Rahul G. Krishnan, Matthew D. Hoffman, and Tony Jebara]
@ -98,11 +97,14 @@ def load_and_parse_ML_20M(data_dir, threshold=4):
load_npz(test_data_true_file), \
load_npz(test_data_test_file),
if not parse:
raise ValueError('Dataset not preprocessed. Please run python3 prepare_dataset.py first.')
LOG.info("Parsing movielens.")
source_file = os.path.join(data_dir, "ml-20m/extracted/ml-20m", "ratings.csv")
if not glob(source_file):
download_movielens(data_dir=data_dir)
raise ValueError('Dataset not downloaded. Please download the ML-20m dataset from https://grouplens.org/datasets/movielens/20m/, unzip it and put it in ', source_file)
raw_data = pd.read_csv(source_file)
raw_data.drop('timestamp', axis=1, inplace=True)

View file

@ -340,27 +340,21 @@ class VAE:
# Therefore we're using the nan-aware mean from numpy to ignore users with no items to be predicted.
return {name: np.nanmean(scores) for name, scores in metrics_scores.items()}
def query(self, input_data: np.ndarray):
def query(self, indices: np.ndarray):
"""
inference for batch size 1
:param input_data:
:return:
"""
query_start = time.time()
indices = np.stack([np.zeros(len(input_data)), input_data], axis=1)
values = np.ones(shape=(1, len(input_data)))
values = np.ones(shape=(1, len(indices)))
values = normalize(values)
values = values.reshape(-1)
sess_run_start = time.time()
res = self.session.run(
self.top_k_query,
feed_dict={self.inputs_query: (indices,
values)})
query_end_time = time.time()
LOG.info('query time: {}'.format(query_end_time - query_start))
LOG.info('sess run time: {}'.format(query_end_time - sess_run_start))
return res
def _increment_global_step(self):