Transformer For PyTorch

This repository provides a script and recipe to train the Transformer model to achieve state-of-the-art accuracy, and is tested and maintained by NVIDIA.

Table Of Contents

Model overview

The Transformer is a Neural Machine Translation (NMT) model which uses an attention mechanism to boost training speed and overall accuracy. The Transformer model was introduced in Attention Is All You Need and improved in Scaling Neural Machine Translation. This implementation is based on the optimized implementation in Facebook's Fairseq NLP toolkit, built on top of PyTorch.

This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, Turing and Ampere GPU architectures. Therefore, researchers can get results 6.5x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.

Model architecture

The Transformer model uses a standard NMT encoder-decoder architecture. Unlike other NMT models, it uses no recurrent connections and operates on a fixed-size context window. The encoder stack is made up of N identical layers, each composed of the following sublayers:

  1. Self-attention layer
  2. Feedforward network (two fully-connected layers)

Like the encoder stack, the decoder stack is made up of N identical layers, each composed of the following sublayers:

  1. Self-attention layer
  2. Multi-headed attention layer combining encoder outputs with results from the previous self-attention layer
  3. Feedforward network (two fully-connected layers)

The encoder uses self-attention to compute a representation of the input sequence. The decoder generates the output sequence one token at a time, taking the encoder output and previous decoder-outputted tokens as inputs. The model also applies embeddings on the input and output tokens, and adds a constant positional encoding. The positional encoding adds information about the position of each token.


Figure 1. The architecture of a Transformer model.

The complete description of the Transformer architecture can be found in the Attention Is All You Need paper.
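
As an illustration of the sinusoidal positional encoding mentioned above, the following sketch (a minimal stand-alone example, not the optimized module from fairseq/modules/) computes the encoding that is added to the token embeddings:

import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # One d_model-sized vector per position 0..max_len-1
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

# The encoding is added to the embedded tokens, e.g.:
# x = embed(tokens) + sinusoidal_positional_encoding(seq_len, d_model)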

Default configuration

The Transformer uses the Byte Pair Encoding tokenization scheme with the Moses decoder. This is a lossy compression method (we drop information about white spaces). Tokenization is applied over the whole WMT14 en-de dataset, including the test set. The default vocabulary size is 33708, excluding all special tokens. The encoder and decoder use shared embeddings. We use 6 blocks in each of the encoder and decoder stacks. The self-attention layer computes its outputs according to the following formula: `Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V`. At each attention step, the model computes 16 different attention representations (which we will call attention heads) and concatenates them. We trained the Transformer model using the Adam optimizer with betas (0.9, 0.997), epsilon 1e-9 and a learning rate of 6e-4. We used the inverse square root training schedule preceded by a linear warmup of 4000 steps. The implementation allows training in mixed precision. We use dynamic loss scaling and a custom mixed precision optimizer. Distributed multi-GPU and multi-node training is implemented with the torch.distributed module with the NCCL backend. For inference, we use beam search with a default beam size of 5. Model performance is evaluated with the BLEU4 metric. For clarity, we report the internal (legacy) BLEU implementation as well as the external SacreBLEU score.
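
The attention formula above can be written directly in PyTorch. The sketch below is a minimal illustration (the optimized multi-head implementation lives in fairseq/modules/); the 16 heads are obtained by splitting the model dimension, applying this function per head and concatenating the results:

import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: tensors of shape (batch, heads, seq_len, d_k)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # QK^T / sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)                         # softmax over the key dimension
    return torch.matmul(weights, v)                                 # weighted sum of values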

Feature support matrix

The following features are supported by this model.

| Feature | Supported |
| --- | --- |
| Multi-GPU training with Distributed Communication Package | Yes |
| Nvidia APEX | Yes |
| AMP | Yes |
| TorchScript | Yes |

Features

  • Multi-GPU training with Distributed Communication Package: Our model uses the torch.distributed package to implement efficient multi-GPU training with NCCL. To enable multi-GPU training with torch.distributed, you have to initialize your model identically in every process spawned by torch.distributed.launch. The distributed strategy is implemented with APEX's DistributedDataParallel. For details, see the example sources in this repo or the PyTorch tutorial; a minimal per-process setup sketch is also shown after this list.

  • Nvidia APEX: The purpose of APEX is to provide an easy and intuitive framework for distributed training and mixed precision training. For details, see the official APEX repository.

  • AMP: This implementation uses Apex's AMP to perform mixed precision training.

  • TorchScript: The Transformer can be converted to the TorchScript format, offering ease of deployment on platforms without Python dependencies. For more information, see the official TorchScript documentation.
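
The following is a minimal sketch of the per-process setup described in the first bullet above; build_model() is a placeholder, and the exact code in this repository differs:

import argparse
import torch
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)   # set by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

model = build_model().cuda()   # placeholder: every rank must build the model identically
model = DDP(model)             # APEX DistributedDataParallel averages gradients across ranks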

Mixed precision training

Mixed precision is the combined use of different numerical precisions in a computational method. Mixed precision training offers significant computational speedup by performing operations in half-precision format, while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of Tensor Cores in the Volta and Turing architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:

  1. Porting the model to use the FP16 data type where appropriate.
  2. Adding loss scaling to preserve small gradient values.

The ability to train deep learning networks with lower precision was introduced in the Pascal architecture and first supported in CUDA 8 in the NVIDIA Deep Learning SDK.

Enabling mixed precision

Mixed precision is enabled using the --amp option in the train.py script. The default optimization level is O2, but it can be overridden with the --amp-level $LVL option (for details see the AMP documentation). The forward and backward passes are computed in FP16, with the exception of the loss function, which is computed in FP32. The default optimization level keeps a copy of the model in higher precision in order to perform accurate weight updates. After the update, the FP32 weights are copied back to the FP16 model. We use dynamic loss scaling with an initial scale of 2^7, increasing it by a factor of 2 every 2000 successful iterations. Overflow is checked after reducing gradients from all of the workers. If infs or NaNs are encountered, the whole batch is dropped.
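
For reference, a generic Apex AMP training loop at optimization level O2 with dynamic loss scaling looks roughly as follows. This is only a sketch of the AMP API, not the repository's actual training loop, which additionally uses a custom mixed precision optimizer:

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  loss_scale="dynamic")      # dynamic loss scaling
for inputs, targets in data_loader:
    loss = criterion(model(inputs), targets)                 # loss computed in FP32
    with amp.scale_loss(loss, optimizer) as scaled_loss:     # scale before backward
        scaled_loss.backward()
    optimizer.step()                                         # FP32 master weight update
    optimizer.zero_grad()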

Enabling TF32

TensorFloat-32 (TF32) is the new math mode in NVIDIA A100 GPUs for handling matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs.

TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require high dynamic range for weights or activations.

For more information, refer to the TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x blog post.

TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
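
In recent PyTorch releases, TF32 usage can be inspected or toggled explicitly with the flags below; this is a generic sketch and not part of this repository's scripts:

import torch

# Query whether TF32 is allowed for matrix multiplications and cuDNN convolutions
print(torch.backends.cuda.matmul.allow_tf32)
print(torch.backends.cudnn.allow_tf32)

# Set to False to force full FP32 math, or True to allow TF32 on Ampere GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True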

Glossary

Attention layer - Layer that computes which elements of the input sequence, or of its hidden representation, contribute the most to the currently considered output element.
Beam search - A heuristic search algorithm which at each prediction step keeps the N most probable outputs as a base for further prediction.
BPE - Byte Pair Encoding, a compression algorithm that finds the most common pair of symbols in the data and replaces them with a new symbol absent from the data.
EOS - End of a sentence.
Self attention layer - Attention layer that computes a hidden representation of the input using the same tensor as query, key and value.
Token - A string that is representable within the model. We also refer to the token's position in the dictionary as a token. There are special non-string tokens: alphabet tokens (all characters in a dataset), the EOS token and the PAD token.
Tokenizer - Object that converts raw strings to sequences of tokens.
Vocabulary embedding - Layer that projects one-hot token representations to a high dimensional space which preserves some information about correlations between tokens.

Setup

The following section lists the requirements in order to start training the Transformer model.

Requirements

This repository contains a Dockerfile which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:

For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:

If you cannot use the PyTorch NGC container, see the versioned NVIDIA Container Support Matrix to set up the required environment or to create your own container.

Quick Start Guide

To train your model using mixed or TF32 precision with Tensor Cores or using FP32, perform the following steps using the default parameters of the Transformer model on the WMT14 English-German dataset. For the specifics concerning training and inference, see the Advanced section.

  1. Clone the repository
git clone https://github.com/NVIDIA/DeepLearningExamples.git 
cd DeepLearningExamples/PyTorch/Translation/Transformer
  2. Build and launch the Transformer PyTorch NGC container
docker build . -t your.repository:transformer
nvidia-docker run -it --rm --ipc=host your.repository:transformer bash

If you already have preprocessed data, use:

nvidia-docker run -it --rm --ipc=host -v <path to your preprocessed data>:/data/wmt14_en_de_joined_dict your.repository:transformer bash

If you already have data downloaded, but it has not yet been preprocessed, use:

nvidia-docker run -it --rm --ipc=host -v <path to your unprocessed data>:/workspace/translation/examples/translation/orig your.repository:transformer bash
  3. Download and preprocess the WMT14 English-German dataset.
scripts/run_preprocessing.sh

After running this command, the data will be downloaded to the /workspace/translation/examples/translation/orig directory, then processed and placed in the /data/wmt14_en_de_joined_dict directory.

  4. Start training
python -m torch.distributed.launch --nproc_per_node 8 /workspace/translation/train.py /data/wmt14_en_de_joined_dict \
  --arch transformer_wmt_en_de_big_t2t \
  --share-all-embeddings \
  --optimizer adam \
  --adam-betas '(0.9, 0.997)' \
  --adam-eps "1e-9" \
  --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt \
  --warmup-init-lr 0.0 \
  --warmup-updates 4000 \
  --lr 0.0006 \
  --min-lr 0.0 \
  --dropout 0.1 \
  --weight-decay 0.0 \
  --criterion label_smoothed_cross_entropy \
  --label-smoothing 0.1 \
  --max-tokens 5120 \
  --seed 1 \
  --fuse-layer-norm \
  --amp \
  --amp-level O2 \
  --save-dir /workspace/checkpoints \
  --distributed-init-method env:// 

The script saves checkpoints every epoch to the directory specified in the --save-dir option. In addition, the best performing checkpoint (in terms of loss) and the latest checkpoint are saved separately. WARNING: If you don't have access to sufficient disk space, use the --save-interval $N option. The checkpoints are ~3.4 GB each. For example, it takes the Transformer model about 30 epochs for the validation loss to plateau. The default is to save the last checkpoint, the best checkpoint and a checkpoint for every epoch, which means (30+1+1)*3.4 GB = 108.8 GB of disk space used. Specifying --save-interval 10 reduces this to (30/10+1+1)*3.4 GB = 17 GB.

  5. Start interactive inference
python inference.py \
  --buffer-size 5000 \
  --path /path/to/your/checkpoint.pt \
  --max-tokens 10240 \
  --fuse-dropout-add \
  --remove-bpe \
  --bpe-codes /path/to/bpe_code_file \
  --fp16

where:

  • The --path option is the location of the checkpoint file.
  • The --bpe-codes option is the location of the BPE code file. If the default training command mentioned above is used, this file can be found in the preprocessed data directory (i.e., /data/wmt14_en_de_joined_dict).

Advanced

The following sections provide greater details of the dataset, running training and inference, and the training results.

Scripts and sample code

The preprocess.py script performs binarization of the dataset obtained and tokenized by the examples/translation/prepare-wmt14en2de.sh script. The train.py script contains the training loop as well as the statistics gathering code. The steps performed in a single training step can be found in fairseq/ddp_trainer.py. The model definition is placed in the file fairseq/models/transformer.py. Model-specific modules, including multiheaded attention and sinusoidal positional embedding, are inside the fairseq/modules/ directory. Finally, the data wrappers are placed inside the fairseq/data/ directory.

Parameters

In this section we give a user-friendly description of the most common options used in the train.py script.

Command-line options

--arch - select the specific configuration for the model. You can select between various predefined hyperparameter values like the number of encoder/decoder blocks, the dropout value or the size of the hidden state representation.
--share-all-embeddings - use the same set of weights for the encoder and decoder word embeddings.
--optimizer - choose the optimization algorithm.
--clip-norm - set a value that gradients will be clipped to.
--lr-scheduler - choose the learning rate change strategy.
--warmup-init-lr - start the linear warmup with a learning rate at this value.
--warmup-updates - set the number of optimization steps after which the linear warmup will end.
--lr - set the learning rate.
--min-lr - prevent the learning rate from falling below this value, regardless of the learning rate schedule.
--dropout - set the dropout value.
--weight-decay - set the weight decay value.
--criterion - select the loss function.
--label-smoothing - distribute the value of one-hot labels between all entries of the dictionary. The value set by this option is subtracted from the one-hot label (see the sketch after this list).
--max-tokens - set the batch size in terms of tokens.
--max-sentences - set the batch size in terms of sentences. Note that the actual batch size will then vary much more than when using the --max-tokens option.
--seed - set the random seed for the NumPy and PyTorch RNGs.
--max-epoch - set the maximum number of epochs.
--online-eval - perform inference on the test set and compute the BLEU score after every epoch.
--target-bleu - works like --online-eval and sets a BLEU score threshold; training stops once the threshold is reached.
--amp - use mixed precision.
--save-dir - set the directory for saving checkpoints.
--distributed-init-method - method for initializing the torch.distributed package. You can either provide addresses with the tcp method or use the environment variable initialization with the env method.
--update-freq - use gradient accumulation. Set the number of training steps across which the gradient will be accumulated.
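
To make the --label-smoothing option concrete, the sketch below shows a simplified label-smoothed negative log-likelihood (ignoring padding tokens, unlike the label_smoothed_cross_entropy criterion used by the repository):

import torch

def label_smoothed_nll_loss(log_probs, target, epsilon):
    # log_probs: (batch, vocab) log-softmax outputs; target: (batch,) gold token indices
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)  # true-class term
    smooth = -log_probs.mean(dim=-1)        # uniform term spread over the whole vocabulary
    return ((1.0 - epsilon) * nll + epsilon * smooth).sum()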

To see the full list of available options and their descriptions, use the -h or --help command line option, for example:

python train.py --help

The following (partial) output is printed when running the sample:

usage: train.py [-h] [--no-progress-bar] [--log-interval N]
                [--log-format {json,none,simple,tqdm}] [--seed N] [--fp16]
                [--profile PROFILE] [--task TASK]
                [--skip-invalid-size-inputs-valid-test] [--max-tokens N]
                [--max-sentences N] [--sentencepiece] [--train-subset SPLIT]
                [--valid-subset SPLIT] [--max-sentences-valid N]
                [--gen-subset SPLIT] [--num-shards N] [--shard-id ID]
                [--distributed-world-size N]
                [--distributed-rank DISTRIBUTED_RANK]
                [--local_rank LOCAL_RANK]
                [--distributed-backend DISTRIBUTED_BACKEND]
                [--distributed-init-method DISTRIBUTED_INIT_METHOD]
                [--distributed-port DISTRIBUTED_PORT] [--device-id DEVICE_ID]
                --arch ARCH [--criterion CRIT] [--max-epoch N]
                [--max-update N] [--target-bleu TARGET] [--clip-norm NORM]
                [--sentence-avg] [--update-freq N] [--optimizer OPT]
                [--lr LR_1,LR_2,...,LR_N] [--momentum M] [--weight-decay WD]
                [--lr-scheduler LR_SCHEDULER] [--lr-shrink LS] [--min-lr LR]
                [--min-loss-scale D] [--enable-parallel-backward-allred-opt]
                [--parallel-backward-allred-opt-threshold N]
                [--enable-parallel-backward-allred-opt-correctness-check]
                [--save-dir DIR] [--restore-file RESTORE_FILE]
                [--save-interval N] [--save-interval-updates N]
                [--keep-interval-updates N] [--no-save]
                [--no-epoch-checkpoints] [--validate-interval N] [--path FILE]
                [--remove-bpe [REMOVE_BPE]] [--cpu] [--quiet] [--beam N]
                [--nbest N] [--max-len-a N] [--max-len-b N] [--min-len N]
                [--no-early-stop] [--unnormalized] [--no-beamable-mm]
                [--lenpen LENPEN] [--unkpen UNKPEN]
                [--replace-unk [REPLACE_UNK]] [--score-reference]
                [--prefix-size PS] [--sampling] [--sampling-topk PS]
                [--sampling-temperature N] [--print-alignment]
                [--model-overrides DICT] [--online-eval] 
                [--bpe-codes CODES] [--fuse-dropout-add] [--fuse-relu-dropout]

Getting the data

The Transformer model was trained on the WMT14 English-German dataset. A concatenation of the commoncrawl, europarl and news-commentary corpora is used as the training and validation dataset, and newstest2014 is used as the test dataset.
This repository contains the run_preprocessing.sh script which automatically downloads and preprocesses the training and test datasets. By default, data will be stored in the /data/wmt14_en_de_joined_dict directory.
Our download script utilizes the Moses decoder to perform tokenization of the dataset and subword-nmt to segment text into subword units (BPE). By default, the script builds a shared vocabulary of 33708 tokens, which is consistent with Scaling Neural Machine Translation.

Dataset guidelines

The Transformer model works with a fixed-size vocabulary. Prior to training, we need to learn a data representation that allows us to store the entire dataset as a sequence of tokens. To achieve this we use Byte Pair Encoding. This algorithm builds a vocabulary by iterating over a dataset, looking for the most frequent pair of symbols and replacing them with a new symbol, as yet absent from the dataset. After producing the desired number of encodings (new symbols can also be merged together), it outputs a code file that is used as an input for the Dictionary class. This approach does not minimize the length of the encoded dataset; alternatively, SentencePiece can be used to tokenize the dataset with the unigram model, which tries to find an encoding close to the theoretical entropy limit. Data is then sorted by length (in terms of tokens) and examples with similar lengths are batched together and padded if necessary.
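
A simplified sketch of the BPE learning loop described above (the repository itself relies on subword-nmt for this step):

from collections import Counter

def learn_bpe(words, num_merges):
    # words: dict mapping a word, spelled as a tuple of symbols, to its corpus frequency
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for word, freq in words.items():
            for pair in zip(word, word[1:]):
                pairs[pair] += freq                    # count adjacent symbol pairs
        if not pairs:
            break
        best = max(pairs, key=pairs.get)               # most frequent pair in the data
        merges.append(best)
        merged = {}
        for word, freq in words.items():               # replace the pair with a new symbol
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    out.append(word[i] + word[i + 1])
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            merged[tuple(out)] = merged.get(tuple(out), 0) + freq
        words = merged
    return merges   # the learned merge operations correspond to the BPE code file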

Multi-dataset

The model has also been tested on the WMT14 English-French dataset, achieving a state-of-the-art accuracy of 41.4 BLEU.

Training process

The default training configuration can be launched by running the train.py training script. By default, the script saves one checkpoint every epoch, in addition to the latest and the best ones. The best checkpoint is considered the one with the lowest value of loss, not the one with the highest BLEU score. To override this behavior, use the --save-interval $N option to save epoch checkpoints every N epochs, or --no-epoch-checkpoints to disable them entirely (with this option the latest and the best checkpoints will still be saved). Specify the save directory with the --save-dir option.
In order to run multi-GPU training, launch the training script with python -m torch.distributed.launch --nproc_per_node $N prepended, where N is the number of GPUs. We have tested up to 16 GPUs on a single node.
After each training epoch, the script runs a loss validation on the validation split of the dataset and outputs the validation loss. By default, BLEU evaluation after each epoch is disabled. To enable it, use the --online-eval option, or to use the BLEU score value as the training stopping condition, use the --target-bleu $TGT option. The computed BLEU scores are case insensitive. The BLEU is computed by the internal fairseq algorithm whose implementation can be found in the fairseq/bleu.py script.
By default, the train.py script will launch FP32 training without Tensor Cores. To use mixed precision with Tensor Cores, use the --amp option.

To reach the BLEU score reported in the Scaling Neural Machine Translation research paper, we used mixed precision training with a batch size of 5120 tokens per GPU and a learning rate of 6e-4 on a DGX-1V system with 8x Tesla V100 16GB. If you use a different setup, we recommend you scale your hyperparameters by applying the following rules:

  1. To use FP32, reduce the batch size to 2560 and set the --update-freq 2 option.
  2. To train on fewer GPUs, multiply --update-freq by the reciprocal of the scaling factor.

For example, when training in FP32 mode on 4 GPUs, use the --update-freq=4 option.
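
These rules amount to keeping the effective global batch size equal to the reference setup of 8 GPUs x 5120 tokens. A small, hypothetical helper (not part of the repository) illustrates the arithmetic:

def required_update_freq(num_gpus, tokens_per_gpu, reference_tokens=8 * 5120):
    # Keep num_gpus * tokens_per_gpu * update_freq equal to the reference global batch
    effective = num_gpus * tokens_per_gpu
    assert reference_tokens % effective == 0, "pick a divisor of the reference batch"
    return reference_tokens // effective

print(required_update_freq(8, 5120))   # 1 - reference mixed precision setup
print(required_update_freq(8, 2560))   # 2 - FP32 on 8 GPUs
print(required_update_freq(4, 2560))   # 4 - FP32 on 4 GPUs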

Inference process

Inference on raw input can be performed by piping the file to be translated into the inference.py script. It requires a pre-trained model checkpoint, a BPE codes file and a dictionary file (both are produced by the run_preprocessing.sh script and can be found in the dataset directory).
In order to run interactive inference, run the command:

python inference.py --path /path/to/your/checkpoint.pt --fuse-dropout-add --remove-bpe --bpe-codes /path/to/code/file

The --buffer-size option allows batching of input sentences up to --max-tokens length.

To test model checkpoint accuracy on the WMT14 test set, run the following command:

sacrebleu -t wmt14/full -l en-de --echo src | python inference.py --buffer-size 5000 --path /path/to/your/checkpoint.pt --max-tokens 10240 --fuse-dropout-add --remove-bpe --bpe-codes /data/code --fp16 | sacrebleu -t wmt14/full -l en-de -lc

Performance

The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA's latest software release. For the most up-to-date performance measurements, go to NVIDIA Data Center Deep Learning Product Performance.

Benchmarking

The following section shows how to run benchmarks measuring the model performance in training and inference modes.

Training performance benchmark

To benchmark the training performance on a specific batch size, run the train.py training script. Performance in tokens/s will be printed to standard output every N iterations, specified by the --log-interval option. Additionally, performance and loss values will be logged by dllogger to the file specified with the --stat-file option. Every line in the output file is a valid JSON object prepended with the DLLL prefix.
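
Assuming the format described above (every line a JSON object prepended with the DLLL prefix), the logged statistics can be read back with a few lines of Python; this helper is a hypothetical convenience, not part of the repository:

import json

def read_dllogger_stats(path):
    records = []
    with open(path) as f:
        for line in f:
            if line.startswith("DLLL"):
                records.append(json.loads(line[len("DLLL"):]))  # strip prefix, parse JSON
    return records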

Inference performance benchmark

To benchmark the inference performance on a specific batch size, run the following command to start the benchmark:

for i in {1..10}; do sacrebleu -t wmt14/full -l en-de --echo src; done | python inference.py --buffer-size 5000 --path /path/to/your/checkpoint.pt --max-tokens 10240 --fuse-dropout-add --remove-bpe --bpe-codes /data/code --fp16 > /dev/null

Results will be printed to stderr.

Results

The following sections provide details on how we achieved our performance and accuracy in training and inference.

Training accuracy results

Following the spirit of the paper A Call for Clarity in Reporting BLEU Scores, we decided to change the evaluation metric implemented in fairseq to the SacreBLEU score. We have calculated that the new metric has an almost linear relationship with the old one. We ran a linear regression on nearly 2000 checkpoints to discover that the SacreBLEU score almost perfectly follows the formula: newScore = 0.978 * oldScore - 0.05.
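
As a quick sanity check of this fit against the numbers reported below (a hypothetical one-liner, not part of the repository):

def sacrebleu_from_legacy(old_score):
    return 0.978 * old_score - 0.05          # fitted linear relationship

print(round(sacrebleu_from_legacy(28.77), 2))  # ~28.09, matching the best reported SacreBLEU score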


Figure 2. Linear relationship between old and new BLEU metric.

To take into account the variability of the results, we computed basic statistics that help us verify whether a model trains correctly. Evaluating nearly 2000 checkpoints from 20 runs, the best score we achieved is 28.09 BLEU (which corresponds to a 28.77 old score). The variance of the score of the best performing model between those 20 runs is 0.011. Knowing that the max statistic is skewed toward higher values, we have also run studies which calculate a threshold beyond which validation loss is no longer correlated with the BLEU score. Of course, our hope is that the dev set's distribution is similar to the test set's distribution, and that when validation loss drops, the BLEU score rises. But due to the finiteness of the validation and test sets, we expect that there is a loss value beyond which performance on the two sets becomes decoupled. To find this point we used the Pearson correlation coefficient as a metric. The results indicate that optimizing beyond a 4.02 validation loss value is no longer beneficial for the BLEU score. Further optimization does not cause overfitting, but results become stochastic. The mean BLEU score after reaching 4.02 validation loss is 27.38. We observe a variance of 0.08, which translates to nearly 0.3 BLEU average difference between the mean score and the obtained score.


Figure 3. Validation loss vs BLEU score. Plots are trimmed to certain validation loss threshold.

Training accuracy: NVIDIA DGX A100 (8x A100 40GB)

Our results were obtained by running the run_DGXA100_AMP_8GPU.sh and run_DGXA100_TF32_8GPU.sh training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX A100 (8x A100 40GB) GPUs. We report average accuracy over 6 runs. We consider a model trained when it reaches minimal validation loss. Time to train contains only training time, without validation. Depending on the configuration and frequency of validation, it can take up to an additional minute per epoch.

| GPUs | Batch size / GPU | Accuracy - TF32 | Accuracy - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (TF32 to mixed precision) |
| --- | --- | --- | --- | --- | --- | --- |
| 8 | 10240 | 27.92 | 27.76 | 2.87 hours | 2.79 hours | x1.03 |

Training accuracy: NVIDIA DGX-1 (8x V100 16GB)

Our results were obtained by running the run_DGX1_AMP_8GPU.sh and run_DGX1_FP32_8GPU.sh training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX-1 (8x V100 16GB) GPUs. We report average accuracy over 6 runs. We consider a model trained when it reaches minimal validation loss. Time to train contains only training time, without validation. Depending on the configuration and frequency of validation, it can take up to an additional minute per epoch. Using mixed precision we could fit a larger batch size in memory, further speeding up the training.

| GPUs | Batch size / GPU | Accuracy - FP32 | Accuracy - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (FP32 to mixed precision) |
| --- | --- | --- | --- | --- | --- | --- |
| 8 | 5120/2560 | 27.66 | 27.82 | 12 hours | 4.6 hours | x2.64 |

Training performance results

Training performance: NVIDIA DGX A100 (8x A100 40GB)

Our results were obtained by running the run_DGXA100_AMP_8GPU.sh and run_DGXA100_TF32_8GPU.sh training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX A100 (8x A100 40GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch.

| GPUs | Batch size / GPU | Throughput - TF32 | Throughput - mixed precision | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
| --- | --- | --- | --- | --- | --- | --- |
| 8 | 10240 | 316913 | 582721 | x1.84 | 6.93 | 7.05 |
| 4 | 10240 | 161980 | 298741 | x1.84 | 3.54 | 3.62 |
| 1 | 10240 | 45755 | 82618 | x1.81 | 1 | 1 |

To achieve these same results, follow the steps in the Quick Start Guide.

Training stability test

The following plot shows average validation loss curves for different configs. We can see that training with AMP O2 converges slightly more slowly than FP32 and TF32 training. In order to mitigate this, you can use the --amp-level O1 option, at the cost of a 20% performance drop compared to the default AMP setting.


Figure 4. Validation loss curves

Training performance: NVIDIA DGX-1 (8x V100 16GB)

Our results were obtained by running the run_DGX1_AMP_8GPU.sh and run_DGX1_FP32_8GPU.sh training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX-1 (8x V100 16GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch. Using mixed precision we could fit a larger batch size in memory, further speeding up the training.

| GPUs | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
| --- | --- | --- | --- | --- | --- | --- |
| 8 | 5120/2560 | 58742 | 223245 | x3.80 | 6.91 | 6.67 |
| 4 | 5120/2560 | 29674 | 115269 | x3.88 | 3.49 | 3.44 |
| 1 | 5120/2560 | 8498 | 33468 | x3.94 | 1 | 1 |

To achieve these same results, follow the steps in the Quick Start Guide.

Training performance: NVIDIA DGX-2 (16x V100 32GB)

Our results were obtained by running the run_DGX1_AMP_8GPU.sh and run_DGX1_FP32_8GPU.sh training scripts, with the number of GPUs set to 16, in the pytorch-20.06-py3 NGC container on NVIDIA DGX-2 (16x V100 32GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch. Using mixed precision we could fit a larger batch size in memory, further speeding up the training.

| GPUs | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
| --- | --- | --- | --- | --- | --- | --- |
| 16 | 10240/5120 | 130867 | 510267 | x3.9 | 13.38 | 12.7 |
| 8 | 10240/5120 | 68829 | 269464 | x3.91 | 7.04 | 6.71 |
| 4 | 10240/5120 | 35168 | 141143 | x4.01 | 3.6 | 3.51 |
| 1 | 10240/5120 | 9779 | 40163 | x4.11 | 1 | 1 |

To achieve these same results, follow the steps in the Quick Start Guide.

Inference performance results

Our implementation of the Transformer has a dynamic batching algorithm, which batches sentences together in such a way that there are no more than N tokens in each batch or no more than M sentences in each batch. In this benchmark we use the first option in order to get the most stable results.
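
A simplified sketch of this token-based batching (the actual logic lives in the fairseq/data/ wrappers) is shown below; sentences are assumed to be pre-sorted by length:

def batch_by_tokens(sentences, max_tokens):
    # sentences: list of token-id lists, sorted by length
    batches, current, longest = [], [], 0
    for sent in sentences:
        new_longest = max(longest, len(sent))
        # padded size of the batch if this sentence were added
        if current and new_longest * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, new_longest = [], len(sent)
        current.append(sent)
        longest = new_longest
    if current:
        batches.append(current)
    return batches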

Inference performance: NVIDIA DGX A100 (1x A100 40GB)

Our results were obtained by running the inference.py inferencing benchmarking script in the pytorch-20.06-py3 NGC container on NVIDIA DGX A100 (1x A100 40GB) GPU.

FP16

| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
| --- | --- | --- | --- | --- | --- |
| 10240 | 9653 | 0.986s | 1.291s | 2.157s | 2.167s |
| 2560 | 5092 | 0.504s | 0.721s | 0.830s | 1.752s |
| 1024 | 2590 | 0.402s | 0.587s | 0.666s | 0.918s |
| 512 | 1357 | 0.380s | 0.561s | 0.633s | 0.788s |
| 256 | 721 | 0.347s | 0.513s | 0.576s | 0.698s |

TF32

| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
| --- | --- | --- | --- | --- | --- |
| 10240 | 7755 | 1.227s | 1.592s | 2.512s | 2.525s |
| 2560 | 4624 | 0.555s | 0.786s | 0.872s | 1.886s |
| 1024 | 2394 | 0.435s | 0.627s | 0.702s | 0.881s |
| 512 | 1275 | 0.405s | 0.586s | 0.663s | 0.821s |
| 256 | 677 | 0.370s | 0.546s | 0.613s | 0.733s |

To achieve these same results, follow the steps in the Quick Start Guide.

Inference performance: NVIDIA DGX-1 (1x V100 16GB)

Our results were obtained by running the inference.py inferencing benchmarking script in the pytorch-20.06-py3 NGC container on an NVIDIA DGX-1 (1x V100 16GB) GPU.

FP16

| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
| --- | --- | --- | --- | --- | --- |
| 10240 | 7464 | 1.283s | 1.704s | 1.792s | 1.801s |
| 2560 | 3596 | 0.719s | 1.066s | 1.247s | 1.423s |
| 1024 | 1862 | 0.563s | 0.857s | 0.936s | 1.156s |
| 512 | 1003 | 0.518s | 0.782s | 0.873s | 1.103s |
| 256 | 520 | 0.484s | 0.723s | 0.813s | 0.992s |

FP32

| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
| --- | --- | --- | --- | --- | --- |
| 10240 | 3782 | 2.531s | 3.091s | 3.121s | 3.136s |
| 2560 | 2910 | 0.888s | 1.221s | 1.252s | 1.432s |
| 1024 | 1516 | 0.692s | 1.001s | 1.126s | 1.297s |
| 512 | 941 | 0.551s | 0.812s | 0.893s | 1.133s |
| 256 | 502 | 0.501s | 0.734s | 0.822s | 0.978s |

To achieve these same results, follow the steps in the Quick Start Guide.

Release notes

Changelog

June 2020

  • add TorchScript support
  • Ampere support

March 2020

  • remove language modeling from the repository
  • one inference script for large chunks of data as well as for interactive demo
  • change custom distributed strategy to APEX's DDP
  • replace custom fp16 training with AMP
  • major refactoring of the codebase

December 2019

  • Change evaluation metric

August 2019

  • add basic AMP support

July 2019

  • Replace custom fused operators with jit functions

June 2019

  • New README

March 2019

  • Add mid-training SacreBLEU evaluation. Better handling of OOMs.

Initial commit, forked from fairseq

Known issues

  • Using a batch size greater than 16k causes an indexing error in the strided_batched_gemm module.