diff --git a/Dockerfile b/Dockerfile index 99013d46d..f8983ba59 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:21.08-py3 +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:21.10-py3 # build an image that includes only the nemo dependencies, ensures that dependencies diff --git a/Jenkinsfile b/Jenkinsfile index 03080d093..7f795212c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,7 +1,7 @@ pipeline { agent { docker { - image 'nvcr.io/nvidia/pytorch:21.08-py3' + image 'gitlab-master.nvidia.com/dl/dgx/pytorch:21.10-py3-devel' args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache/torch:/root/.cache/torch --shm-size=8g' } } @@ -10,6 +10,7 @@ pipeline { disableConcurrentBuilds() } stages { + stage('PyTorch version') { steps { sh 'python -c "import torch; print(torch.__version__)"' @@ -23,6 +24,12 @@ pipeline { } } + stage('Code formatting checks') { + steps { + sh 'python setup.py style' + } + } + stage('Copyright Headers check') { steps { sh 'python tests/check_copyright_header.py --dir .' @@ -35,12 +42,6 @@ pipeline { } } - stage('Code formatting checks') { - steps { - sh 'python setup.py style' - } - } - stage('Torch TTS unit tests') { when { anyOf { @@ -58,7 +59,14 @@ pipeline { } } - stage('Installation') { + // Revert once import guards are added by PTL or version comparing is fixed + stage('Install PyTorch Lighting 1.5 RC') { + steps{ + sh 'pip install pytorch-lightning==1.5.0rc0 && sed -i "s/from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining/try:\\n\\tfrom pytorch_lightning.callbacks.quantization import QuantizationAwareTraining\\nexcept:\\n\\tpass/g" /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py' + } + } + + stage('NeMo Installation') { steps { sh './reinstall.sh release' } @@ -605,30 +613,31 @@ pipeline { } } - stage('L2: Multi-GPU Megatron finetuning') { - when { - anyOf { - branch 'main' - changeRequest target: 'main' - } - } - failFast true - parallel { - stage('L2: Cased Megatron finetuning on MRPC') { - steps { - sh 'cd examples/nlp/glue_benchmark && \ - python glue_benchmark.py \ - model.dataset.data_dir=/home/TestData/nlp/glue_fake/MRPC \ - trainer.gpus=[0,1] \ - +trainer.fast_dev_run=true \ - model.dataset.use_cache=false \ - model.language_model.pretrained_model_name=megatron-bert-345m-cased \ - trainer.accelerator=ddp \ - exp_manager=null' - } - } - } - } + // TODO: add test once megatron-bert is supported again + // stage('L2: Multi-GPU Megatron finetuning') { + // when { + // anyOf { + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // parallel { + // stage('L2: Cased Megatron finetuning on MRPC') { + // steps { + // sh 'cd examples/nlp/glue_benchmark && \ + // python glue_benchmark.py \ + // model.dataset.data_dir=/home/TestData/nlp/glue_fake/MRPC \ + // trainer.gpus=[0,1] \ + // +trainer.fast_dev_run=true \ + // model.dataset.use_cache=false \ + // model.language_model.pretrained_model_name=megatron-bert-345m-cased \ + // trainer.accelerator=ddp \ + // exp_manager=null' + // } + // } + // } + // } stage('L2: SGD-QA') { when { @@ -725,7 +734,6 @@ pipeline { model.language_model.pretrained_model_name=bert-base-uncased \ model.dataset.version_2_with_negative=false \ trainer.precision=16 \ - trainer.amp_level=O1 \ trainer.gpus=[0] \ exp_manager=null' } @@ -747,7 +755,6 @@ pipeline { 
model.language_model.pretrained_model_name=bert-base-uncased \ model.dataset.version_2_with_negative=true \ trainer.precision=16 \ - trainer.amp_level=O1 \ trainer.gpus=[1] \ exp_manager=null' } @@ -777,30 +784,31 @@ pipeline { } } // Runs out of memory on the 12G TITAN V (GPU 0 on main CI) - stage('L2: MegaBERT Token Classification') { - when { - anyOf { - branch 'main' - changeRequest target: 'main' - } - } - failFast true - steps { - sh 'cd examples/nlp/token_classification && \ - python token_classification_train.py \ - model.dataset.data_dir=/home/TestData/nlp/token_classification_punctuation/ \ - model.language_model.pretrained_model_name=megatron-bert-345m-uncased \ - model.train_ds.batch_size=10 \ - model.dataset.max_seq_length=50 \ - model.dataset.use_cache=false \ - trainer.accelerator=ddp \ - trainer.precision=16 \ - trainer.amp_level=O1 \ - trainer.gpus=[1] \ - +trainer.fast_dev_run=true \ - exp_manager=null' - } - } + // TODO: add when megatron bert is supported again in NeMo + // stage('L2: MegaBERT Token Classification') { + // when { + // anyOf { + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // steps { + // sh 'cd examples/nlp/token_classification && \ + // python token_classification_train.py \ + // model.dataset.data_dir=/home/TestData/nlp/token_classification_punctuation/ \ + // model.language_model.pretrained_model_name=megatron-bert-345m-uncased \ + // model.train_ds.batch_size=10 \ + // model.dataset.max_seq_length=50 \ + // model.dataset.use_cache=false \ + // trainer.accelerator=ddp \ + // trainer.precision=16 \ + // trainer.gpus=[1] \ + // +trainer.fast_dev_run=true \ + // exp_manager=null' + // } + // } + stage('L2: Parallel SQUAD v1.1 & v2.0') { when { anyOf { @@ -810,8 +818,13 @@ pipeline { } failFast true parallel { - stage('SQUAD v2.0 with Megatron with ckpt & config') { + // TODO: use megatron bert when supported again + stage('SQUAD v2.0 with DistilBERT Uncased') { + // stage('SQUAD v2.0 with Megatron with ckpt & config') { // Cannot do fast_dev_run because squad needs whole dev dataset + // model.language_model.pretrained_model_name=megatron-bert-uncased \ + // model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \ + // model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \ steps { sh 'cd examples/nlp/question_answering && \ python question_answering_squad.py \ @@ -825,12 +838,9 @@ pipeline { trainer.max_epochs=1 \ +trainer.max_steps=1 \ model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \ - model.language_model.pretrained_model_name=megatron-bert-uncased \ - model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \ - model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \ + model.language_model.pretrained_model_name=distilbert-base-uncased \ model.dataset.version_2_with_negative=true \ trainer.precision=16 \ - trainer.amp_level=O1 \ trainer.gpus=[1] \ exp_manager=null' } @@ -852,7 +862,6 @@ pipeline { model.language_model.pretrained_model_name=roberta-base \ model.dataset.version_2_with_negative=false \ trainer.precision=16 \ - trainer.amp_level=O1 \ trainer.gpus=[0] \ exp_manager=null' } @@ -864,7 +873,7 @@ pipeline { model.dataset.num_classes=6 \ model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \ model.validation_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \ - 
model.language_model.pretrained_model_name=bert-base-uncased \ + model.language_model.pretrained_model_name=distilbert-base-uncased \ model.train_ds.batch_size=10 \ model.dataset.max_seq_length=50 \ model.dataset.use_cache=false \ @@ -889,107 +898,106 @@ pipeline { } } - stage('L2: Model Parallel Size 2 Megatron Text Classification') { - when { - anyOf{ - branch 'main' - changeRequest target: 'main' - } - } - failFast true - steps{ - sh 'cd examples/nlp/text_classification && \ - python text_classification_with_bert.py \ - trainer.gpus=[0,1] \ - trainer.num_nodes=1 \ - trainer.precision=16 \ - trainer.gradient_clip_val=1.0 \ - ~trainer.amp_level \ - +trainer.fast_dev_run=true \ - model.dataset.num_classes=6 \ - model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \ - model.train_ds.batch_size=4 \ - model.language_model.pretrained_model_name=megatron-bert-uncased \ - model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \ - model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \ - model.nemo_path=null \ - ~model.infer_samples \ - exp_manager=null' - } - } + // TODO: add when megatron-bert is supported again + // stage('L2: Model Parallel Size 2 Megatron Text Classification') { + // when { + // anyOf{ + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // steps{ + // sh 'cd examples/nlp/text_classification && \ + // python text_classification_with_bert.py \ + // trainer.gpus=[0,1] \ + // trainer.num_nodes=1 \ + // trainer.precision=16 \ + // trainer.gradient_clip_val=1.0 \ + // +trainer.fast_dev_run=true \ + // model.dataset.num_classes=6 \ + // model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \ + // model.train_ds.batch_size=4 \ + // model.language_model.pretrained_model_name=megatron-bert-uncased \ + // model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \ + // model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \ + // model.nemo_path=null \ + // ~model.infer_samples \ + // exp_manager=null' + // } + // } - stage('L2: Model Parallel Size 2 Megatron Autoresume') { - when { - anyOf{ - branch 'main' - changeRequest target: 'main' - } - } - failFast true - steps{ - sh 'cd examples/nlp/text_classification && \ - python text_classification_with_bert.py \ - trainer.gpus=[0,1] \ - trainer.num_nodes=1 \ - trainer.precision=16 \ - trainer.gradient_clip_val=1.0 \ - trainer.max_epochs=1 \ - ~trainer.amp_level \ - +trainer.fast_dev_run=true \ - model.dataset.num_classes=6 \ - model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \ - model.train_ds.batch_size=4 \ - model.language_model.pretrained_model_name=megatron-bert-uncased \ - model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \ - model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \ - model.nemo_path=null \ - ~model.infer_samples \ - +exp_manager.explicit_log_dir=/home/TestData/nlp/mp_autoresume \ - +exp_manager.resume_if_exists=true' - } - } + // stage('L2: Model Parallel Size 2 Megatron Autoresume') { + // when { + // anyOf{ + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // steps{ + // sh 'cd examples/nlp/text_classification && \ + // python text_classification_with_bert.py \ + // trainer.gpus=[0,1] \ + // trainer.num_nodes=1 \ + // trainer.precision=16 \ + // trainer.gradient_clip_val=1.0 \ + // trainer.max_epochs=1 \ + // 
+trainer.fast_dev_run=true \ + // model.dataset.num_classes=6 \ + // model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \ + // model.train_ds.batch_size=4 \ + // model.language_model.pretrained_model_name=megatron-bert-uncased \ + // model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \ + // model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \ + // model.nemo_path=null \ + // ~model.infer_samples \ + // +exp_manager.explicit_log_dir=/home/TestData/nlp/mp_autoresume \ + // +exp_manager.resume_if_exists=true' + // } + // } - stage('L2: Model Parallel Size 2 Megatron Evaluation from .nemo') { - when { - anyOf{ - branch 'main' - changeRequest target: 'main' - } - } - failFast true - steps{ - sh 'cd examples/nlp/text_classification && \ - python model_parallel_text_classification_evaluation.py \ - trainer.gpus=[0,1] \ - trainer.num_nodes=1 \ - model.dataset.num_classes=6 \ - model.test_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \ - model.nemo_path=/home/TestData/nlp/mp_2_nemo/retail_text_class_350M.nemo \ - exp_manager=null' - } - } + // stage('L2: Model Parallel Size 2 Megatron Evaluation from .nemo') { + // when { + // anyOf{ + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // steps{ + // sh 'cd examples/nlp/text_classification && \ + // python model_parallel_text_classification_evaluation.py \ + // trainer.gpus=[0,1] \ + // trainer.num_nodes=1 \ + // model.dataset.num_classes=6 \ + // model.test_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \ + // model.nemo_path=/home/TestData/nlp/mp_2_nemo/retail_text_class_350M.nemo \ + // exp_manager=null' + // } + // } - stage('L2: Model Parallel Size 2 Megatron Train from .nemo') { - when { - anyOf{ - branch 'main' - changeRequest target: 'main' - } - } - failFast true - steps{ - sh 'cd examples/nlp/token_classification && \ - python token_classification_train.py \ - pretrained_model=/home/TestData/nlp/mp_2_nemo/ner_350M.nemo \ - model.dataset.data_dir=/home/TestData/nlp/ner/ \ - model.train_ds.batch_size=2 \ - model.dataset.use_cache=false \ - trainer.gpus=[0,1] \ - +trainer.fast_dev_run=true \ - model.dataset.class_balancing="weighted_loss" \ - exp_manager=null' - } - } + // stage('L2: Model Parallel Size 2 Megatron Train from .nemo') { + // when { + // anyOf{ + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // steps{ + // sh 'cd examples/nlp/token_classification && \ + // python token_classification_train.py \ + // pretrained_model=/home/TestData/nlp/mp_2_nemo/ner_350M.nemo \ + // model.dataset.data_dir=/home/TestData/nlp/ner/ \ + // model.train_ds.batch_size=2 \ + // model.dataset.use_cache=false \ + // trainer.gpus=[0,1] \ + // +trainer.fast_dev_run=true \ + // model.dataset.class_balancing="weighted_loss" \ + // exp_manager=null' + // } + // } stage('L2: Parallel NLP Examples 2') { when { @@ -1089,7 +1097,6 @@ pipeline { --config-name=bert_pretraining_from_text_config.yaml \ trainer.gpus=[0] \ trainer.precision=16 \ - trainer.amp_level=O1 \ +trainer.fast_dev_run=true \ model.train_ds.data_file=/home/TestData/nlp/wikitext-2/train.txt \ model.train_ds.batch_size=32 \ @@ -1116,7 +1123,6 @@ pipeline { --config-name=bert_pretraining_from_preprocessed_config.yaml \ trainer.gpus=[1] \ trainer.precision=16 \ - trainer.amp_level=O1 \ +trainer.fast_dev_run=true \ model.train_ds.data_file=/home/TestData/nlp/wiki_book_mini/training \ 
model.train_ds.batch_size=8 \ @@ -1351,39 +1357,40 @@ pipeline { } } - stage('L2: NMT Megatron BERT Model Parallel Size 2 Encoder') { - when { - anyOf{ - branch 'main' - changeRequest target: 'main' - } - } - failFast true - steps{ - sh 'cd examples/nlp/machine_translation && \ - python enc_dec_nmt.py \ - --config-path=conf \ - --config-name=megatron \ - model.encoder.model_name=megatron-bert-uncased \ - model.encoder.checkpoint_file=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \ - model.encoder.hidden_size=1024 \ - model.encoder.num_attention_heads=16 \ - model.encoder.num_layers=24 \ - model.encoder.max_position_embeddings=512 \ - model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ - model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \ - model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ - model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ - model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ - model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ - model.decoder_tokenizer.tokenizer_model=/home/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \ - model.decoder.hidden_size=1024 \ - trainer.gpus=[0,1] \ - +trainer.fast_dev_run=true \ - exp_manager=null \ - ' - } - } + // TODO: add when megatron bert is supported again in NeMo + // stage('L2: NMT Megatron BERT Model Parallel Size 2 Encoder') { + // when { + // anyOf{ + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // steps{ + // sh 'cd examples/nlp/machine_translation && \ + // python enc_dec_nmt.py \ + // --config-path=conf \ + // --config-name=megatron \ + // model.encoder.model_name=megatron-bert-uncased \ + // model.encoder.checkpoint_file=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \ + // model.encoder.hidden_size=1024 \ + // model.encoder.num_attention_heads=16 \ + // model.encoder.num_layers=24 \ + // model.encoder.max_position_embeddings=512 \ + // model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ + // model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \ + // model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ + // model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ + // model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ + // model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \ + // model.decoder_tokenizer.tokenizer_model=/home/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \ + // model.decoder.hidden_size=1024 \ + // trainer.gpus=[0,1] \ + // +trainer.fast_dev_run=true \ + // exp_manager=null \ + // ' + // } + // } stage('L2: NMT Tarred Dataset Creation') { when { diff --git a/README.rst b/README.rst index d80db4b4e..a9424e386 100644 --- a/README.rst +++ b/README.rst @@ -85,7 +85,7 @@ Requirements ------------ 1) Python 3.6, 3.7 or 3.8 -2) Pytorch 1.8.1 or above +2) Pytorch 1.10.0 or above 3) NVIDIA GPU for training Documentation @@ -163,16 +163,28 @@ Note that RNNT requires numba to be installed from conda. pip uninstall numba conda install -c numba numba +Megatron GPT +~~~~~~~~~~~~ +Megatron GPT training requires NVIDIA Apex to be installed. + +.. 
code-block:: bash + + git clone https://github.com/NVIDIA/apex + cd apex + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + Docker containers: ~~~~~~~~~~~~~~~~~~ -If you chose to work with main branch, we recommend using NVIDIA's PyTorch container version 21.08-py3 and then installing from GitHub. +If you choose to work with the main branch, we recommend using NVIDIA's PyTorch container version 21.10-py3 and then installing from GitHub. +Note that NVIDIA's PyTorch 21.10-py3 container has not yet been released publicly. Please use a container with a nightly version of PyTorch installed if you are +unable to access NVIDIA's PyTorch 21.10 container. .. code-block:: bash docker run --gpus all -it --rm -v :/NeMo --shm-size=8g \ -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit \ - stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:21.08-py3 + stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:21.10-py3 Examples -------- diff --git a/examples/asr/conf/conformer/conformer_ctc_bpe.yaml b/examples/asr/conf/conformer/conformer_ctc_bpe.yaml index 91b8dc477..e9219ab21 100644 --- a/examples/asr/conf/conformer/conformer_ctc_bpe.yaml +++ b/examples/asr/conf/conformer/conformer_ctc_bpe.yaml @@ -156,7 +156,6 @@ trainer: accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 10 # Interval of logging. progress_bar_refresh_rate: 10 diff --git a/examples/asr/conf/conformer/conformer_ctc_char.yaml b/examples/asr/conf/conformer/conformer_ctc_char.yaml index c6a3d085e..38eb194bb 100644 --- a/examples/asr/conf/conformer/conformer_ctc_char.yaml +++ b/examples/asr/conf/conformer/conformer_ctc_char.yaml @@ -131,7 +131,6 @@ trainer: accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 10 # Interval of logging. progress_bar_refresh_rate: 10 diff --git a/examples/asr/conf/conformer/conformer_transducer_bpe.yaml b/examples/asr/conf/conformer/conformer_transducer_bpe.yaml index 9727336ec..e002b0cce 100644 --- a/examples/asr/conf/conformer/conformer_transducer_bpe.yaml +++ b/examples/asr/conf/conformer/conformer_transducer_bpe.yaml @@ -206,7 +206,6 @@ trainer: accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 10 # Interval of logging. progress_bar_refresh_rate: 10 diff --git a/examples/asr/conf/conformer/conformer_transducer_char.yaml b/examples/asr/conf/conformer/conformer_transducer_char.yaml index bbc1f2f31..3f5d0ec2f 100644 --- a/examples/asr/conf/conformer/conformer_transducer_char.yaml +++ b/examples/asr/conf/conformer/conformer_transducer_char.yaml @@ -201,7 +201,6 @@ trainer: accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 10 # Interval of logging.
progress_bar_refresh_rate: 10 diff --git a/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain.yaml b/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain.yaml index 00268fcef..3a2708da9 100644 --- a/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain.yaml +++ b/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain.yaml @@ -100,7 +100,6 @@ trainer: accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 10 # Interval of logging. resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. diff --git a/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain_large.yaml b/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain_large.yaml index 4a92a281c..030467f67 100644 --- a/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain_large.yaml +++ b/examples/asr/experimental/wav2vec/configs/wav2vec_pretrain_large.yaml @@ -100,7 +100,6 @@ trainer: accelerator: ddp accumulate_grad_batches: 1 gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. log_every_n_steps: 10 # Interval of logging. resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. diff --git a/examples/nlp/dialogue_state_tracking/conf/sgdqa_config.yaml b/examples/nlp/dialogue_state_tracking/conf/sgdqa_config.yaml index 7ecc5652a..8aa7529cf 100644 --- a/examples/nlp/dialogue_state_tracking/conf/sgdqa_config.yaml +++ b/examples/nlp/dialogue_state_tracking/conf/sgdqa_config.yaml @@ -22,7 +22,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches gradient_clip_val: 1.0 - amp_level: O1 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP. accelerator: ddp log_every_n_steps: 5 # Interval of logging. diff --git a/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml b/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml index 4e5a0f993..aadb4a66e 100644 --- a/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml +++ b/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml @@ -15,7 +15,6 @@ tagger_trainer: logger: false # provided by exp_manager accumulate_grad_batches: 1 # accumulates grads every k batches gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. accelerator: ddp @@ -66,7 +65,6 @@ decoder_trainer: logger: false # provided by exp_manager accumulate_grad_batches: 1 # accumulates grads every k batches gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. accelerator: ddp log_every_n_steps: 1 # Interval of logging. 
diff --git a/examples/nlp/entity_linking/conf/tiny_example_entity_linking_config.yaml b/examples/nlp/entity_linking/conf/tiny_example_entity_linking_config.yaml index 36a027306..5ec5e0aec 100644 --- a/examples/nlp/entity_linking/conf/tiny_example_entity_linking_config.yaml +++ b/examples/nlp/entity_linking/conf/tiny_example_entity_linking_config.yaml @@ -7,7 +7,6 @@ trainer: max_steps: null accumulate_grad_batches: 1 precision: 16 - amp_level: O1 accelerator: ddp gradient_clip_val: 0.0 log_every_n_steps: 1 diff --git a/examples/nlp/entity_linking/conf/umls_medical_entity_linking_config.yaml b/examples/nlp/entity_linking/conf/umls_medical_entity_linking_config.yaml index 5b6709e72..c404e3682 100644 --- a/examples/nlp/entity_linking/conf/umls_medical_entity_linking_config.yaml +++ b/examples/nlp/entity_linking/conf/umls_medical_entity_linking_config.yaml @@ -7,7 +7,6 @@ trainer: max_steps: null accumulate_grad_batches: 1 precision: 16 - amp_level: O1 accelerator: ddp gradient_clip_val: 0.0 log_every_n_steps: 1 diff --git a/examples/nlp/glue_benchmark/glue_benchmark_config.yaml b/examples/nlp/glue_benchmark/glue_benchmark_config.yaml index 23b5522ae..47fccb78e 100644 --- a/examples/nlp/glue_benchmark/glue_benchmark_config.yaml +++ b/examples/nlp/glue_benchmark/glue_benchmark_config.yaml @@ -7,7 +7,6 @@ trainer: max_epochs: 3 max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches - amp_level: O0 # O1/O2 for mixed precision precision: 16 accelerator: ddp checkpoint_callback: False # Provided by exp_manager diff --git a/examples/nlp/information_retrieval/conf/bert_ir_config.yaml b/examples/nlp/information_retrieval/conf/bert_ir_config.yaml index e4fe72b34..16d8c0359 100644 --- a/examples/nlp/information_retrieval/conf/bert_ir_config.yaml +++ b/examples/nlp/information_retrieval/conf/bert_ir_config.yaml @@ -7,7 +7,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches precision: 16 # 16 to use AMP - amp_level: O1 # O1 or O2 if using AMP accelerator: ddp log_every_n_steps: 1 # Interval of logging. val_check_interval: 0.05 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations diff --git a/examples/nlp/intent_slot_classification/conf/intent_slot_classification_config.yaml b/examples/nlp/intent_slot_classification/conf/intent_slot_classification_config.yaml index 502458c1d..3d405615c 100644 --- a/examples/nlp/intent_slot_classification/conf/intent_slot_classification_config.yaml +++ b/examples/nlp/intent_slot_classification/conf/intent_slot_classification_config.yaml @@ -6,7 +6,6 @@ trainer: max_epochs: 50 max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 amp_level to enable the AMP. accelerator: ddp log_every_n_steps: 1 # Interval of logging. 
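The trainer config hunks above and below all make the same change: the Apex-style ``amp_level`` entry is dropped, leaving ``precision`` as the single switch for mixed precision. As a minimal sketch (illustrative values only, mirroring the updated ``pl.Trainer(...)`` calls later in this change), the equivalent construction in code is:

.. code-block:: python

    import pytorch_lightning as pl

    # Mixed precision is requested with precision=16 alone; no amp_level argument is passed.
    trainer = pl.Trainer(gpus=1, accelerator="ddp", precision=16, logger=False, checkpoint_callback=False)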
diff --git a/examples/nlp/language_modeling/conf/bert_pretraining_from_preprocessed_config.yaml b/examples/nlp/language_modeling/conf/bert_pretraining_from_preprocessed_config.yaml index 5f8dd7d23..d4e005ced 100644 --- a/examples/nlp/language_modeling/conf/bert_pretraining_from_preprocessed_config.yaml +++ b/examples/nlp/language_modeling/conf/bert_pretraining_from_preprocessed_config.yaml @@ -8,7 +8,6 @@ trainer: replace_sampler_ddp: false # needed for bert pretraining from preproc accumulate_grad_batches: 1 # accumulates grads every k batches precision: 16 # 16 to use AMP - amp_level: O1 # O1 or O2 if using AMP accelerator: ddp gradient_clip_val: 1.0 log_every_n_steps: 1 diff --git a/examples/nlp/language_modeling/conf/bert_pretraining_from_text_config.yaml b/examples/nlp/language_modeling/conf/bert_pretraining_from_text_config.yaml index da5a03d86..454b9e6ba 100644 --- a/examples/nlp/language_modeling/conf/bert_pretraining_from_text_config.yaml +++ b/examples/nlp/language_modeling/conf/bert_pretraining_from_text_config.yaml @@ -7,7 +7,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches precision: 16 # 16 to use AMP - amp_level: O1 # O1 or O2 if using AMP accelerator: ddp gradient_clip_val: 0.0 log_every_n_steps: 1 diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml new file mode 100644 index 000000000..14c1cf65f --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -0,0 +1,117 @@ +name: megatron_gpt +restore_from_path: null # used when starting from a .nemo file + +trainer: + gpus: 1 + num_nodes: 1 + accelerator: ddp + precision: 16 + logger: False # logger provided by exp_manager + checkpoint_callback: False + replace_sampler_ddp: False + max_epochs: null + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_gpt + create_wandb_logger: True + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}' + + +model: + # model parallelism + micro_batch_size: 4 + tensor_model_parallel_size: 1 + + # model architecture + encoder_seq_length: 512 + max_position_embeddings: 512 + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 12 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + hidden_dropout: 0.1 # Dropout probability for hidden state transformer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + layernorm_epsilon: 1e-5 + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. 
+ pre_process: True # add embedding + post_process: True # add pooler + + tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fused_fp16: True # False if using fp32 or bf16 + fused_bf16: False # True if using bf16 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + + # not implemented in NeMo yet + activations_checkpoint_method: null # 'uniform', 'block' + activations_checkpoint_num_layers: 1 + + data: + # Path to data must be specified by the user. + # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Or see example below: + # data_prefix: + # - .5 + # - /raid/data/pile/my-gpt3_00_text_document + # - .5 + # - /raid/data/pile/my-gpt3_01_text_document + data_prefix: ??? + data_impl: mmap + splits_string: 900,50,50 + seq_length: 1024 + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + + optim: + name: adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/conf/transformer_lm_config.yaml b/examples/nlp/language_modeling/conf/transformer_lm_config.yaml index 044ffdaa3..1e6b92c7f 100644 --- a/examples/nlp/language_modeling/conf/transformer_lm_config.yaml +++ b/examples/nlp/language_modeling/conf/transformer_lm_config.yaml @@ -88,7 +88,6 @@ trainer: gpus: 4 num_nodes: 1 max_epochs: 200 - amp_level: O2 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: False diff --git a/examples/nlp/language_modeling/megatron_gpt_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_gpt_ckpt_to_nemo.py new file mode 100644 index 000000000..5478c18f8 --- /dev/null +++ b/examples/nlp/language_modeling/megatron_gpt_ckpt_to_nemo.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +from argparse import ArgumentParser + +import torch.multiprocessing as mp +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import AppState, logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--checkpoint_folder", + type=str, + default=None, + required=True, + help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints", + ) + parser.add_argument( + "--checkpoint_name", + type=str, + default=None, + required=True, + help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt", + ) + + parser.add_argument( + "--hparams_file", + type=str, + default=None, + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--nemo_file_path", type=str, default=None, required=True, help="Path to output .nemo file.") + + parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None) + + args = parser.parse_args() + return args + + +def convert(rank, world_size, args): + + app_state = AppState() + app_state.data_parallel_rank = 0 + trainer = Trainer(gpus=args.tensor_model_parallel_size) + # TODO: reach out to PTL For an API-safe local rank override + trainer.accelerator.training_type_plugin._local_rank = rank + checkpoint_path = os.path.join(args.checkpoint_folder, f'mp_rank_{rank:02d}', args.checkpoint_name) + model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer) + model._save_restore_connector = NLPSaveRestoreConnector() + model.save_to(args.nemo_file_path) + logging.info(f'NeMo model saved to: {args.nemo_file_path}') + + +def main() -> None: + args = get_args() + world_size = args.tensor_model_parallel_size + mp.spawn(convert, args=(world_size, args), nprocs=world_size, join=True) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py new file mode 100644 index 000000000..563921f9b --- /dev/null +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -0,0 +1,90 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from argparse import ArgumentParser + +import torch +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader + +from nemo.collections.nlp.data.language_modeling.megatron.gpt_request_dataset import GPTRequestDataset +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin +from nemo.utils import logging +from nemo.utils.app_state import AppState + +assert torch.cuda.is_available() + + +def main(): + parser = ArgumentParser() + parser.add_argument("--model_file", type=str, default="", required=True, help="Pass path to model's .nemo file") + parser.add_argument( + "--prompt", type=str, default="", required=True, help="Prompt for the model (a text to complete)" + ) + parser.add_argument( + "--tokens_to_generate", type=int, default="64", required=False, help="How many tokens to add to prompt" + ) + parser.add_argument( + "--stop_after_sentence", + type=bool, + default="True", + required=False, + help="True/False: whether to stop after full sentence has been generated.", + ) + parser.add_argument( + "--tensor_model_parallel_size", type=int, default=1, required=True, + ) + parser.add_argument("--precision", default=32, help="PyTorch Lightning Trainer precision flag") + + args = parser.parse_args() + + # cast precision to int if 32 or 16 + if args.precision in ["32", "16"]: + args.precision = int(float(args.precision)) + + # trainer required for restoring model parallel models + trainer = Trainer(plugins=NLPDDPPlugin(), gpus=args.tensor_model_parallel_size, precision=args.precision) + + app_state = AppState() + if args.tensor_model_parallel_size is not None and args.tensor_model_parallel_size > 1: + app_state.model_parallel_size = args.tensor_model_parallel_size + app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size) + + model = MegatronGPTModel.restore_from(restore_path=args.model_file, trainer=trainer) + + model.freeze() + + request = { + "prompt": args.prompt, + "tokens_to_generate": args.tokens_to_generate, + "stop_after_sentence": args.stop_after_sentence, + } + + dataset = GPTRequestDataset(request, model.tokenizer) + + request_dl = DataLoader(dataset) + + response = trainer.predict(model, request_dl) + + print("***************************") + print(response[0]['completion']['text']) + print("***************************") + logging.info(f"Generation stopped because: {response[0]['completion']['stop reason']}") + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py new file mode 100644 index 000000000..c1c850b7f --- /dev/null +++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from omegaconf.omegaconf import OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank +from nemo.collections.nlp.parts.nlp_overrides import ( + NLPCheckpointConnector, + NLPDDPPlugin, + NLPNativeBfloat16PrecisionPlugin, + NLPNativeMixedPrecisionPlugin, + NLPPrecisionPlugin, +) +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + if cfg.trainer.precision == 16: + trainer = Trainer( + plugins=[ + NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), + NLPNativeMixedPrecisionPlugin( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + ), + ], + **cfg.trainer, + ) + elif cfg.trainer.precision == 'bf16': + trainer = Trainer( + plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPNativeBfloat16PrecisionPlugin(),], + **cfg.trainer, + ) + else: + trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPPrecisionPlugin()], **cfg.trainer) + + exp_manager(trainer, cfg.exp_manager) + + # update resume from checkpoint found by exp_manager + resume_from_checkpoint = trainer.resume_from_checkpoint + if resume_from_checkpoint is not None: + mp_rank = compute_model_parallel_rank(trainer.local_rank, cfg.model.tensor_model_parallel_size) + resume_from_checkpoint = Path(resume_from_checkpoint) + resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(f'mp_rank_{mp_rank:02d}').joinpath( + resume_from_checkpoint.name + ) + resume_from_checkpoint = str(resume_from_checkpoint) + logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}') + + trainer.checkpoint_connector = NLPCheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint) + + model = MegatronGPTModel(cfg.model, trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/megatron_gpt_test.py b/examples/nlp/language_modeling/megatron_gpt_test.py new file mode 100644 index 000000000..6a927223c --- /dev/null +++ b/examples/nlp/language_modeling/megatron_gpt_test.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from omegaconf.omegaconf import OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank +from nemo.collections.nlp.parts.nlp_overrides import ( + NLPDDPPlugin, + NLPNativeBfloat16PrecisionPlugin, + NLPNativeMixedPrecisionPlugin, + NLPPrecisionPlugin, + NLPSaveRestoreConnector, +) +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.app_state import AppState + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = None + if cfg.trainer.precision == 16: + trainer = Trainer( + plugins=[ + NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), + NLPNativeMixedPrecisionPlugin( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + ), + ], + **cfg.trainer, + ) + elif cfg.trainer.precision == 'bf16': + trainer = Trainer( + plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPNativeBfloat16PrecisionPlugin(),], + **cfg.trainer, + ) + else: + trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPPrecisionPlugin()], **cfg.trainer) + + app_state = AppState() + app_state.model_parallel_size = cfg.model.tensor_model_parallel_size + app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size) + + model = MegatronGPTModel.restore_from( + cfg.restore_from_path, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(), + ) + + # Note: most NeMo models must have the data paths configured before instantiating the model + # MegatronGPTModel sets up the data in the PTL method .setup which happens after DDP spawns.
+ model.cfg.data.splits_string = cfg.model.data.splits_string + + trainer.test(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/machine_translation/conf/aayn_base.yaml b/examples/nlp/machine_translation/conf/aayn_base.yaml index 5c8d59a08..03a6295e5 100644 --- a/examples/nlp/machine_translation/conf/aayn_base.yaml +++ b/examples/nlp/machine_translation/conf/aayn_base.yaml @@ -145,7 +145,6 @@ trainer: gpus: 4 num_nodes: 1 max_epochs: 200 - amp_level: O2 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: False diff --git a/examples/nlp/machine_translation/conf/aayn_bottleneck.yaml b/examples/nlp/machine_translation/conf/aayn_bottleneck.yaml index 6bf0414f4..6dad50175 100644 --- a/examples/nlp/machine_translation/conf/aayn_bottleneck.yaml +++ b/examples/nlp/machine_translation/conf/aayn_bottleneck.yaml @@ -172,7 +172,6 @@ trainer: gpus: 4 num_nodes: 1 max_epochs: 200 - amp_level: O2 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: False diff --git a/examples/nlp/machine_translation/conf/huggingface.yaml b/examples/nlp/machine_translation/conf/huggingface.yaml index c7f432b6c..0b95e9c59 100644 --- a/examples/nlp/machine_translation/conf/huggingface.yaml +++ b/examples/nlp/machine_translation/conf/huggingface.yaml @@ -119,7 +119,6 @@ trainer: gpus: 4 num_nodes: 1 max_epochs: 200 - amp_level: O2 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: False diff --git a/examples/nlp/machine_translation/conf/megatron.yaml b/examples/nlp/machine_translation/conf/megatron.yaml index c17799284..56007172d 100644 --- a/examples/nlp/machine_translation/conf/megatron.yaml +++ b/examples/nlp/machine_translation/conf/megatron.yaml @@ -147,7 +147,6 @@ trainer: gpus: 4 num_nodes: 1 max_epochs: 200 - amp_level: O2 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: False diff --git a/examples/nlp/question_answering/conf/question_answering_squad_config.yaml b/examples/nlp/question_answering/conf/question_answering_squad_config.yaml index 929409d53..328a08bb3 100644 --- a/examples/nlp/question_answering/conf/question_answering_squad_config.yaml +++ b/examples/nlp/question_answering/conf/question_answering_squad_config.yaml @@ -10,7 +10,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches precision: 16 # 16 to use AMP - amp_level: O1 # O1 or O2 if using AMP accelerator: ddp gradient_clip_val: 0.0 val_check_interval: 1.0 # check once per epoch .25 for 4 times per epoch diff --git a/examples/nlp/text_classification/conf/text_classification_config.yaml b/examples/nlp/text_classification/conf/text_classification_config.yaml index 8470991ea..f05b5e0fa 100644 --- a/examples/nlp/text_classification/conf/text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/text_classification_config.yaml @@ -21,7 +21,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 32 # Should be set to 16 for O1 and O2 to enable the 
AMP. accelerator: ddp log_every_n_steps: 1 # Interval of logging. diff --git a/examples/nlp/token_classification/conf/punctuation_capitalization_config.yaml b/examples/nlp/token_classification/conf/punctuation_capitalization_config.yaml index 0d2790775..729784af3 100644 --- a/examples/nlp/token_classification/conf/punctuation_capitalization_config.yaml +++ b/examples/nlp/token_classification/conf/punctuation_capitalization_config.yaml @@ -24,7 +24,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: false # Provided by exp_manager diff --git a/examples/nlp/token_classification/conf/token_classification_config.yaml b/examples/nlp/token_classification/conf/token_classification_config.yaml index 8cb162a60..09c9d84d7 100644 --- a/examples/nlp/token_classification/conf/token_classification_config.yaml +++ b/examples/nlp/token_classification/conf/token_classification_config.yaml @@ -23,7 +23,6 @@ trainer: max_steps: null # precedence over max_epochs accumulate_grad_batches: 1 # accumulates grads every k batches gradient_clip_val: 0.0 - amp_level: O0 # O1/O2 for mixed precision precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0 accelerator: ddp checkpoint_callback: False # Provided by exp_manager diff --git a/examples/nlp/token_classification/punctuation_capitalization_evaluate.py b/examples/nlp/token_classification/punctuation_capitalization_evaluate.py index b91b332a8..3cd5e57ad 100644 --- a/examples/nlp/token_classification/punctuation_capitalization_evaluate.py +++ b/examples/nlp/token_classification/punctuation_capitalization_evaluate.py @@ -62,13 +62,7 @@ def main(cfg: DictConfig) -> None: else: gpu = 1 if cfg.trainer.gpus != 0 else 0 - trainer = pl.Trainer( - gpus=gpu, - precision=cfg.trainer.precision, - amp_level=cfg.trainer.amp_level, - logger=False, - checkpoint_callback=False, - ) + trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision, logger=False, checkpoint_callback=False,) exp_dir = exp_manager(trainer, cfg.exp_manager) if not cfg.pretrained_model: diff --git a/examples/nlp/token_classification/token_classification_evaluate.py b/examples/nlp/token_classification/token_classification_evaluate.py index 4a888f50d..3a03e6270 100644 --- a/examples/nlp/token_classification/token_classification_evaluate.py +++ b/examples/nlp/token_classification/token_classification_evaluate.py @@ -70,13 +70,7 @@ def main(cfg: DictConfig) -> None: else: gpu = 1 if cfg.trainer.gpus != 0 else 0 - trainer = pl.Trainer( - gpus=gpu, - precision=cfg.trainer.precision, - amp_level=cfg.trainer.amp_level, - logger=False, - checkpoint_callback=False, - ) + trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision, logger=False, checkpoint_callback=False,) exp_dir = exp_manager(trainer, cfg.exp_manager) if not cfg.pretrained_model: diff --git a/examples/nlp/zero_shot_intent_recognition/conf/zero_shot_intent_config.yaml b/examples/nlp/zero_shot_intent_recognition/conf/zero_shot_intent_config.yaml index 4be1636f2..d07ad249e 100644 --- a/examples/nlp/zero_shot_intent_recognition/conf/zero_shot_intent_config.yaml +++ b/examples/nlp/zero_shot_intent_recognition/conf/zero_shot_intent_config.yaml @@ -19,7 +19,6 @@ trainer: max_epochs: 1 max_steps: null # precedence over max_epochs 
accumulate_grad_batches: 1 # accumulates grads every k batches - amp_level: O0 # O1/O2 for mixed precision precision: 16 accelerator: ddp log_every_n_steps: 1 # Interval of logging. diff --git a/examples/speaker_tasks/recognition/conf/SpeakerNet_recognition_3x2x512.yaml b/examples/speaker_tasks/recognition/conf/SpeakerNet_recognition_3x2x512.yaml index b07cb0277..fc1e9ac58 100644 --- a/examples/speaker_tasks/recognition/conf/SpeakerNet_recognition_3x2x512.yaml +++ b/examples/speaker_tasks/recognition/conf/SpeakerNet_recognition_3x2x512.yaml @@ -143,7 +143,6 @@ trainer: num_nodes: 1 accelerator: ddp accumulate_grad_batches: 1 - amp_level: O0 deterministic: True checkpoint_callback: False logger: False diff --git a/examples/speaker_tasks/recognition/conf/SpeakerNet_verification_3x2x256.yaml b/examples/speaker_tasks/recognition/conf/SpeakerNet_verification_3x2x256.yaml index 9ce23f1f8..27811862d 100644 --- a/examples/speaker_tasks/recognition/conf/SpeakerNet_verification_3x2x256.yaml +++ b/examples/speaker_tasks/recognition/conf/SpeakerNet_verification_3x2x256.yaml @@ -131,7 +131,6 @@ trainer: num_nodes: 1 accelerator: ddp accumulate_grad_batches: 1 - amp_level: O0 deterministic: True checkpoint_callback: False logger: False diff --git a/examples/speaker_tasks/recognition/conf/ecapa_tdnn.yaml b/examples/speaker_tasks/recognition/conf/ecapa_tdnn.yaml index b3763927e..d990a68d8 100644 --- a/examples/speaker_tasks/recognition/conf/ecapa_tdnn.yaml +++ b/examples/speaker_tasks/recognition/conf/ecapa_tdnn.yaml @@ -93,7 +93,6 @@ trainer: num_nodes: 1 accelerator: ddp accumulate_grad_batches: 1 - amp_level: O0 deterministic: False checkpoint_callback: False logger: False diff --git a/examples/tts/conf/squeezewave.yaml b/examples/tts/conf/squeezewave.yaml index 9fc620f40..befecd00f 100644 --- a/examples/tts/conf/squeezewave.yaml +++ b/examples/tts/conf/squeezewave.yaml @@ -95,7 +95,6 @@ trainer: log_every_n_steps: 200 check_val_every_n_epoch: 25 precision: 16 - amp_level: O1 exp_manager: exp_dir: null diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py index 3d7882e49..a0245b866 100644 --- a/nemo/collections/common/callbacks/callbacks.py +++ b/nemo/collections/common/callbacks/callbacks.py @@ -29,7 +29,7 @@ class LogEpochTimeCallback(Callback): self.epoch_start = time.time() @rank_zero_only - def on_train_epoch_end(self, trainer, pl_module, outputs): + def on_train_epoch_end(self, trainer, pl_module): curr_time = time.time() duration = curr_time - self.epoch_start trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step) diff --git a/nemo/collections/common/parts/ptl_overrides.py b/nemo/collections/common/parts/ptl_overrides.py new file mode 100644 index 000000000..19f2ac2b1 --- /dev/null +++ b/nemo/collections/common/parts/ptl_overrides.py @@ -0,0 +1,23 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin + + +class NeMoNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None: + super().__init__(precision=16) + + self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval) diff --git a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py index f94d150bc..a4c784605 100644 --- a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py +++ b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py @@ -33,6 +33,7 @@ class AutoTokenizer(TokenizerSpec): self, pretrained_model_name: str, vocab_file: Optional[str] = None, + merges_file: Optional[str] = None, mask_token: Optional[str] = None, bos_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -60,19 +61,26 @@ class AutoTokenizer(TokenizerSpec): use_fast: whether to use fast HuggingFace tokenizer """ try: - if vocab_file is not None: - message = 'Using "slow" HuggingFace tokenizer' - if use_fast: - message += f'{vocab_file} is ignored in "fast" tokenizers, using a "slow" version' + # this logic deals with different huggingface tokenizers having different positional args + if vocab_file is None: self.tokenizer = AUTOTOKENIZER.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=False + pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast, + ) + elif merges_file is None: + self.tokenizer = AUTOTOKENIZER.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=use_fast, ) else: self.tokenizer = AUTOTOKENIZER.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast + pretrained_model_name_or_path=pretrained_model_name, + vocab_file=vocab_file, + merges_file=merges_file, + use_fast=use_fast, ) except Exception as e: - raise ValueError(f'{pretrained_model_name} is not supported by HuggingFace. {e}') + raise ValueError( + f'Unable to instantiate HuggingFace AUTOTOKENIZER for {pretrained_model_name}. Exception: {e}' + ) special_tokens_dict = {} diff --git a/nemo/collections/nlp/data/language_modeling/__init__.py b/nemo/collections/nlp/data/language_modeling/__init__.py index 16831be85..de9e60c6b 100644 --- a/nemo/collections/nlp/data/language_modeling/__init__.py +++ b/nemo/collections/nlp/data/language_modeling/__init__.py @@ -17,4 +17,5 @@ from nemo.collections.nlp.data.language_modeling.lm_bert_dataset import ( BertPretrainingDataset, BertPretrainingPreprocessedDataloader, ) +from nemo.collections.nlp.data.language_modeling.megatron import GPTDataset, IndexedDataset, MMapIndexedDataset from nemo.collections.nlp.data.language_modeling.sentence_dataset import SentenceDataset, TarredSentenceDataset diff --git a/nemo/collections/nlp/data/language_modeling/megatron/Makefile b/nemo/collections/nlp/data/language_modeling/megatron/Makefile new file mode 100644 index 000000000..150939026 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/Makefile @@ -0,0 +1,23 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/nemo/collections/nlp/data/language_modeling/megatron/__init__.py b/nemo/collections/nlp/data/language_modeling/megatron/__init__.py new file mode 100644 index 000000000..0062f83a4 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset +from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import IndexedDataset, MMapIndexedDataset + +# TODO: refactor these datasets to work without megatron-lm dependency +# from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset +# from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset diff --git a/nemo/collections/nlp/data/language_modeling/megatron/bert_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/bert_dataset.py new file mode 100644 index 000000000..49df69601 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/bert_dataset.py @@ -0,0 +1,216 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
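The Makefile added above compiles the pybind11 helpers extension (helpers.cpp, added later in this diff) in place next to the Python sources. A minimal sketch of triggering that build from Python, mirroring what the compile_helper utility introduced in dataset_utils.py does (this assumes make, a C++ toolchain, and pybind11 are available in the environment):

import os
import subprocess

# Directory holding the Makefile and helpers.cpp from this diff.
megatron_dir = os.path.join(
    'nemo', 'collections', 'nlp', 'data', 'language_modeling', 'megatron'
)

# Builds helpers.<abi-tag>.so in place; a non-zero return code means compilation failed.
result = subprocess.run(['make', '-C', megatron_dir])
if result.returncode != 0:
    raise RuntimeError('Building the megatron dataset helpers failed')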
+ +"""BERT Style dataset.""" + +import numpy as np +import torch + +from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import ( + create_masked_lm_predictions, + create_tokens_and_tokentypes, + get_a_and_b_segments, + get_samples_mapping, + truncate_segments, +) + + +class BertDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + masked_lm_prob, + max_seq_length, + short_seq_prob, + seed, + binary_head, + tokenizer, + ): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping( + self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens + short_seq_prob, + self.seed, + self.name, + self.binary_head, + ) + + # Vocab stuff. + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls_id + self.sep_id = tokenizer.sep_id + self.mask_id = tokenizer.mask_id + self.pad_id = tokenizer.pad_id + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2 ** 32)) + return build_training_sample( + sample, + seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, + self.sep_id, + self.mask_id, + self.pad_id, + self.masked_lm_prob, + np_rng, + self.binary_head, + ) + + +def build_training_sample( + sample, + target_seq_length, + max_seq_length, + vocab_id_list, + vocab_id_to_token_dict, + cls_id, + sep_id, + mask_id, + pad_id, + masked_lm_prob, + np_rng, + binary_head, +): + """Biuld training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + """ + + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + + # Truncate to `target_sequence_length`. 
+ max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng) + + # Build tokens and toketypes. + tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id) + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + ) + + # Padding. + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy( + tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length + ) + + train_sample = { + 'text': tokens_np, + 'types': tokentypes_np, + 'labels': labels_np, + 'is_random': int(is_next_random), + 'loss_mask': loss_mask_np, + 'padding_mask': padding_mask_np, + 'truncated': int(truncated), + } + return train_sample + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np diff --git a/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py new file mode 100644 index 000000000..8b3706fd8 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py @@ -0,0 +1,75 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time + +import numpy as np +import torch + +from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + + +class BlendableDataset(torch.utils.data.Dataset): + def __init__(self, datasets, weights): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. 
+ weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indecies. + start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + try: + if is_global_rank_zero(): + from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper + + compile_helper() + from nemo.collections.nlp.data.language_modeling.megatron import helpers + except: + raise Exception(f'Could not compile helpers.') + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + logging.info( + '> elapsed time for building blendable dataset indices: ' '{:.2f} (sec)'.format(time.time() - start_time) + ) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py new file mode 100644 index 000000000..d0342a180 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py @@ -0,0 +1,121 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataloaders.""" + + +import torch + +from nemo.utils import logging + + +class MegatronPretrainingSampler: + def __init__( + self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True + ): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + logging.info( + f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} and consumed_samples: {consumed_samples}' + ) + + # Sanity checks. 
+ assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, 'no samples left to consume: {}, {}'.format( + self.consumed_samples, self.total_samples + ) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, ( + 'data_parallel_rank should be smaller than data size: {}, ' + '{}'.format(self.data_parallel_rank, data_parallel_size) + ) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class MegatronPretrainingRandomSampler: + def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, ( + 'data_parallel_rank should be smaller than data size: {}, ' + '{}'.format(self.data_parallel_rank, data_parallel_size) + ) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + + batch = [] + # Last batch if not complete will be dropped. 
+ for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py new file mode 100644 index 000000000..706c99273 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py @@ -0,0 +1,732 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018 The Google AI Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import collections +import math +import os +import subprocess +import time + +import numpy as np +import torch +from apex.transformer import parallel_state + +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset +from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + +DSET_TYPE_BERT = 'standard_bert' +DSET_TYPE_ICT = 'ict' +DSET_TYPE_T5 = 't5' + +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] + + +def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0] * num_datasets + prefixes = [0] * num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. 
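To make the 0.5% over-provisioning concrete, here is a small, self-contained illustration of the computation the loop below performs (the weights and sample counts are made-up values):

import math

weights = [0.25, 0.75]                         # already normalized
train_valid_test_num_samples = [1000, 100, 10]

per_dataset_samples = [
    [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
    for weight in weights
]
# per_dataset_samples == [[252, 26, 3], [754, 76, 8]]; summed per split this gives
# [1006, 102, 11], slightly more than requested, so blending never runs short.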
+ datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples] + ) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(['make', '-C', path]) + if ret.returncode != 0: + logging.error("Making C++ dataset helpers module failed, exiting.") + import sys + + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, 'make sure each sample has at least two sentences.' + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + # print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, + geometric_dist=False, + masking_style="bert", +): + """Creates the predictions for the masked LM objective. 
+ Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) + + num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) + + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + if not geometric_dist: + # Note(mingdachen): + # By default, we set the probilities to favor shorter ngram sequences. + pvals = 1.0 / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx : idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + (masked_lms, masked_spans) = ([], []) + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + if not geometric_dist: + n = np_rng.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) + else: + # Sampling "n" from the geometric distribution and clipping it to + # the max_ngrams. Using p=0.2 default from the SpanBERT paper + # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) + n = min(np_rng.geometric(0.2), max_ngrams) + + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. 
+ if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + masked_token = None + if masking_style == "bert": + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] + elif masking_style == "t5": + masked_token = mask_id + else: + raise ValueError("invalid value of masking style") + + output_tokens[index] = masked_token + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + + masked_spans.append(MaskedLmInstance(index=index_set, label=[tokens[index] for index in index_set])) + + assert len(masked_lms) <= num_to_predict + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + # Sort the spans by the index of the first span + masked_spans = sorted(masked_spans, key=lambda x: x.index[0]) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans) + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. 
+ filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head=False, + max_seq_length_dec=None, + dataset_type='standard_bert', +): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + max_seq_length_dec, + dataset_type=dataset_type, + ) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + max_seq_length_dec, + dataset_type='standard_bert', +): + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is desinged to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. 
+ logging.info(' > dataset split:') + + def print_split_stats(name, index): + logging.info(' {}:'.format(name)) + logging.info( + ' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logging.info( + ' sentence indices in [{}, {}) total of {} ' + 'sentences'.format(start_index, end_index, end_index - start_index) + ) + + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset + from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset + + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + ) + + if dataset_type == DSET_TYPE_T5: + dataset = T5Dataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + max_seq_length_dec=max_seq_length_dec, + short_seq_prob=short_seq_prob, + **kwargs, + ) + elif dataset_type == DSET_TYPE_BERT: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs, + ) + else: + raise NotImplementedError("Dataset type not fully implemented.") + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. 
+ assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + + logging.info(' > building dataset index ...') + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + logging.info(' > finished creating indexed dataset in {:4f} ' 'seconds'.format(time.time() - start_time)) + + logging.info(' > indexed dataset stats:') + logging.info(' number of documents: {}'.format(indexed_dataset.doc_idx.shape[0] - 1)) + logging.info(' number of sentences: {}'.format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + + +def get_samples_mapping( + indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head +): + """Get a list that maps a sample index to a starting sentence index, end sentence index, and length""" + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) + indexmap_filename += '_{}s'.format(seed) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename): + print( + ' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename) + ) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + logging.info(' > building samples index mapping for {} ...'.format(name)) + # First compile and then import. 
+ + try: + if is_global_rank_zero(): + compile_helper() + from nemo.collections.nlp.data.language_modeling.megatron import helpers + except: + raise Exception(f'Could not compile helpers.') + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1, + ) + logging.info(' > done building samples index maping') + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + logging.info(' > saved the index mapping in {}'.format(indexmap_filename)) + # Make sure all the ranks have built the mapping + logging.info( + ' > elasped time to build and save samples mapping ' '(seconds): {:4f}'.format(time.time() - start_time) + ) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group()) + ) + + # Load indexed dataset. + logging.info(' > loading indexed mapping from {}'.format(indexmap_filename)) + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time)) + logging.info(' total number of samples: {}'.format(samples_mapping.shape[0])) + + return samples_mapping diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py new file mode 100644 index 000000000..30a1219a9 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py @@ -0,0 +1,449 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
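The gpt_dataset module that follows exposes build_train_valid_test_datasets for GPT-style pretraining data. A rough usage sketch under stated assumptions: cfg and trainer stand in for the owning NeMo model's config and PyTorch Lightning trainer, and /data/my_corpus_text_document is a hypothetical prefix produced by Megatron-style data preprocessing; in practice NeMo's Megatron GPT model constructs these datasets internally rather than by hand:

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import (
    build_train_valid_test_datasets,
)

# cfg and trainer are assumed to come from the owning model; the data prefix,
# split string, and sample counts below are illustrative only.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    cfg=cfg,
    trainer=trainer,
    data_prefix=['/data/my_corpus_text_document'],
    data_impl='mmap',
    splits_string='980,10,10',
    train_valid_test_num_samples=[10000, 100, 100],
    seq_length=1024,
    seed=1234,
    skip_warmup=True,
)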
+ +"""GPT style dataset.""" + +import os +import time + +import numpy as np +import torch +from apex.transformer import parallel_state + +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import ( + get_datasets_weights_and_num_samples, + get_train_valid_test_split_, +) +from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset +from nemo.collections.nlp.data.language_modeling.megatron.megatron_dataset import MegatronDataset +from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + + +def build_train_valid_test_datasets( + cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup +): + """Build train, valid, and test datasets.""" + + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + cfg, + trainer, + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + ) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + cfg, + trainer, + prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup +): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. 
+ logging.info(' > dataset split:') + + def print_split_stats(name, index): + logging.info(' {}:'.format(name)) + logging.info( + ' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) + dataset = GPTDataset( + cfg, + trainer, + name, + data_prefix, + documents, + indexed_dataset, + train_valid_test_num_samples[index], + seq_length, + seed, + ) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + logging.info(' > building dataset index ...') + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + logging.info(' > finished creating indexed dataset in {:4f} ' 'seconds'.format(time.time() - start_time)) + logging.info(' number of documents: {}'.format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class GPTDataset(MegatronDataset): + def __init__(self, cfg, trainer, name, data_prefix, documents, indexed_dataset, num_samples, seq_length, seed): + + super().__init__(cfg, trainer=trainer) + self.name = name + self.indexed_dataset = indexed_dataset + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, data_prefix, documents, self.indexed_dataset.sizes, num_samples, seq_length, seed + ) + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1 + ) + else: + # Otherwise, get the rest of the initial document. + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) + sample = np.concatenate(sample_list) + + return {'text': np.array(sample, dtype=np.int64)} + + +def _build_index_mappings(name, data_prefix, documents, sizes, num_samples, seq_length, seed): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. 
+ shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = data_prefix + _filename += '_{}_indexmap'.format(name) + _filename += '_{}ns'.format(num_samples) + _filename += '_{}sl'.format(seq_length) + _filename += '_{}s'.format(seed) + doc_idx_filename = _filename + '_doc_idx.npy' + sample_idx_filename = _filename + '_sample_idx.npy' + shuffle_idx_filename = _filename + '_shuffle_idx.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0: + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + + logging.info(' > WARNING: could not find index map files, building ' 'the indices on rank 0 ...') + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' 'separate_last_epoch to False', flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, 'last epoch number of samples should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples < ( + num_samples_per_epoch + 1 + ), 'last epoch number of samples exceeded max value.' + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch) + if separate_last_epoch: + string = ( + ' > last epoch number of samples ({}) is smaller ' + 'than 80% of number of samples per epoch ({}), ' + 'setting separate_last_epoch to True' + ) + else: + string = ( + ' > last epoch number of samples ({}) is larger ' + 'than 80% of number of samples per epoch ({}), ' + 'setting separate_last_epoch to False' + ) + print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + logging.info( + ' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time) + ) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. 
+ assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + try: + if is_global_rank_zero(): + from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper + + compile_helper() + from nemo.collections.nlp.data.language_modeling.megatron import helpers + except: + raise Exception(f'Could not compile helpers.') + + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) + # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, + # num_epochs, tokens_per_epoch) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + logging.info( + ' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time) + ) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + logging.info( + ' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time) + ) + + torch.distributed.barrier() + + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group()) + ) + + # Load mappings. + start_time = time.time() + logging.info(' > loading doc-idx mapping from {}'.format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') + logging.info(' > loading sample-idx mapping from {}'.format(sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + logging.info(' > loading shuffle-idx mapping from {}'.format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time)) + logging.info(' total number of samples: {}'.format(sample_idx.shape[0])) + logging.info(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence lenght, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. 
+ Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Begining offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - 1 + remaining_seq_length = 0 + else: + # Otherwise, start from the begining of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + print( + ' > building shuffle index with split [0, {}) and [{}, {}) ' + '...'.format(num_samples, num_samples, total_size), + flush=True, + ) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_request_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_request_dataset.py new file mode 100644 index 000000000..5b0462805 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_request_dataset.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict + +import torch +from torch.utils.data.dataset import Dataset + + +class GPTRequestDataset(Dataset): + def __init__(self, request: Dict, tokenizer) -> None: + super().__init__() + self.request = request + self.tokenizer = tokenizer + + # tokenize prompt + self.request['tokenized_prompt'] = self.tokenizer.text_to_tokens(request['prompt']) + tokens = self.tokenizer.text_to_ids(request['prompt']) + self.request['tokens'] = torch.tensor(tokens) + + def __len__(self): + return 1 + + def __getitem__(self, index): + return self.request diff --git a/nemo/collections/nlp/data/language_modeling/megatron/helpers.cpp b/nemo/collections/nlp/data/language_modeling/megatron/helpers.cpp new file mode 100644 index 000000000..b36060f87 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/helpers.cpp @@ -0,0 +1,716 @@ +/* +Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. 
+ auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. 
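+            // Each record is a (document index, token offset) pair; sample `i` spans the
+            // tokens between record `i` and record `i + 1`, which is why the mapping holds
+            // `num_samples + 1` rows.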
+ sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. 
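+        // `map_index` counts samples during the first (sizing) pass and is the write
+        // position for (start_sentence, end_sentence, target_seq_len) triples during
+        // the second (populating) pass.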
+ uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... 
+ } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. 
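+    // unchecked<1>() exposes a raw 1-D view of each numpy buffer so the loops
+    // below avoid per-element bounds checking.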
+ auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. 
+ if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." 
<< endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py new file mode 100644 index 000000000..d24e1e424 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py @@ -0,0 +1,564 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Most of the code here has been copied from: +# fairseq/fairseq/data/indexed_dataset.py + +# with some modifications: + +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. 
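The builder/reader pair defined below writes a `.bin` data file plus a `.idx` index file per dataset prefix. A minimal round trip, shown as an illustrative sketch rather than as part of the patch (the prefix, token ids, and vocabulary size are placeholders):

    import torch
    from nemo.collections.nlp.data.language_modeling.megatron import indexed_dataset

    prefix = "my_corpus_text_document"  # hypothetical output prefix
    builder = indexed_dataset.make_builder(
        indexed_dataset.data_file_path(prefix), impl="mmap", vocab_size=50257
    )
    for doc in [[10, 11, 12], [13, 14]]:  # one token-id list per document
        builder.add_item(torch.IntTensor(doc))
        builder.end_document()
    builder.finalize(indexed_dataset.index_file_path(prefix))

    ds = indexed_dataset.make_dataset(prefix, impl="mmap", skip_warmup=True)
    print(len(ds), ds[0])  # item count and the first token array

Besides "mmap", which memory-maps the data file, the `impl` argument also accepts "lazy" and "cached" (see get_available_dataset_impl below).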
+ +import os +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + +import numpy as np +import torch + +from nemo.utils import logging + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.' 
+ ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float: 4, np.double: 8} + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + 
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' 0, "Provide the argument --vocab-extra-ids 100 to the script" + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + + start_index, end_index, seq_length = self.samples_mapping[idx] + sample = [] + for index in range(start_index, end_index): + sample.append(self.indexed_dataset[index]) + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + np_rng = np.random.RandomState(seed=(self.seed + idx)) + return build_training_sample( + sample, + seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, + self.sep_id, + self.mask_id, + self.pad_id, + self.masked_lm_prob, + np_rng, + self.bos_id, + self.eos_id, + self.sentinel_tokens, + ) + + +def build_training_sample( + sample, + target_seq_length, + max_seq_length, + max_seq_length_dec, + vocab_id_list, + vocab_id_to_token_dict, + cls_id, + sep_id, + mask_id, + pad_id, + masked_lm_prob, + np_rng, + bos_id=None, + eos_id=None, + sentinel_tokens=None, +): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + bos_id: start of decoder example id + eos_id: end of generation id + sentinel_tokens: unique value to be substituted for every replaced span + """ + + assert target_seq_length <= max_seq_length + + # flatten sentences into one list + tokens = [token for sentence in sample for token in sentence] + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = len(tokens) > max_num_tokens + tokens = tokens[:max_num_tokens] + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=10, + geometric_dist=True, + masking_style="t5", + ) + + # Padding. 
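+    # Worked example of the layout produced below, with <X> one sentinel token:
+    #   original tokens : [a, b, c, d, e]   with span (c, d) masked
+    #   encoder input   : [a, b, <X>, e]        (padded with pad_id to max_seq_length)
+    #   decoder input   : [<bos>, <X>, c, d]    (padded with pad_id to max_seq_length_dec)
+    #   decoder labels  : [<X>, c, d, <eos>]    (padded with -1)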
+ tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask = pad_and_convert_to_numpy( + tokens, + masked_positions, + masked_labels, + pad_id, + max_seq_length, + max_seq_length_dec, + masked_spans, + bos_id, + eos_id, + sentinel_tokens, + ) + + train_sample = { + 'text_enc': tokens_enc, + 'text_dec': tokens_dec_in, + 'labels': labels, + 'loss_mask': loss_mask, + 'truncated': int(truncated), + 'enc_mask': enc_mask, + 'dec_mask': dec_mask, + 'enc_dec_mask': enc_dec_mask, + } + return train_sample + + +def pad_and_convert_to_numpy( + tokens, + masked_positions, + masked_labels, + pad_id, + max_seq_length, + max_seq_length_dec, + masked_spans=None, + bos_id=None, + eos_id=None, + sentinel_tokens=None, +): + """Pad sequences and convert them to numpy.""" + + sentinel_tokens = collections.deque(sentinel_tokens) + t5_input = [] + (t5_decoder_in, t5_decoder_out) = ([bos_id], []) + (start_index, end_index) = (0, None) + for span in masked_spans: + flag = sentinel_tokens.popleft() + + # Append the same tokens in decoder input and output + t5_decoder_in.append(flag) + t5_decoder_in.extend(span.label) + t5_decoder_out.append(flag) + t5_decoder_out.extend(span.label) + + end_index = span.index[0] + t5_input.extend(tokens[start_index:end_index]) + t5_input.append(flag) + + # the next start index is the token after the last span token + start_index = span.index[-1] + 1 + + # Add token to the t5_decoder_out + t5_decoder_out.append(eos_id) + + # Add the remaining tokens to the t5 input + t5_input.extend(tokens[start_index:]) + + # assert (len(t5_input) - len(masked_spans)) + \ + # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens) + + # Some checks. + + # Encoder-side padding mask. + num_tokens = len(t5_input) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(masked_positions) == len(masked_labels) + + # Tokens.. + filler = [pad_id] * padding_length + tokens_enc = np.array(t5_input + filler, dtype=np.int64) + + # Decoder-side padding mask. + num_tokens_dec = len(t5_decoder_in) + padding_length_dec = max_seq_length_dec - num_tokens_dec + assert padding_length_dec >= 0 + filler_dec = [pad_id] * padding_length_dec + tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64) + + # Create attention masks + enc_mask = make_attention_mask(tokens_enc, tokens_enc) + enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc) + dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in) + dec_mask = dec_mask * make_history_mask(tokens_dec_in) + + # Labels mask. 
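+    # Label positions past the real decoder tokens are filled with -1; they carry
+    # no gradient because loss_mask below is 0 over the padded region.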
+ labels = t5_decoder_out + ([-1] * padding_length_dec) + labels = np.array(labels, dtype=np.int64) + + # Loss mask + loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec) + loss_mask = np.array(loss_mask, dtype=np.int64) + + return tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask + + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + + +def make_attention_mask_3d(source_block, target_block): + """ + Returns a 3-dimensional (3-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1) + # (batch, source_length, target_length) + # mask = mask.astype(np.int64) + return mask + + +def make_history_mask(block): + length = block.shape[0] + arange = np.arange(length) + history_mask = arange[None,] <= arange[:, None] + history_mask = history_mask.astype(np.int64) + return history_mask + + +def make_history_mask_3d(block): + batch, length = block.shape + arange = torch.arange(length, device=block.device) + history_mask = (arange[None,] <= arange[:, None])[ + None, + ] + history_mask = history_mask.expand(batch, length, length) + return history_mask diff --git a/nemo/collections/nlp/models/language_modeling/megatron/__init__.py b/nemo/collections/nlp/models/language_modeling/megatron/__init__.py new file mode 100644 index 000000000..36494a48b --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from nemo.collections.nlp.models.language_modeling.megatron.bert_model import BertModel +from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel + +# from nemo.collections.nlp.models.language_modeling.megatron.t5_model import T5Model diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py new file mode 100644 index 000000000..7abba3fa1 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""BERT model.""" + +import torch +from apex.transformer import tensor_parallel + +from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType +from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits +from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.utils import ( + erf_gelu, + get_linear_layer, + init_method_normal, + openai_gelu, + scaled_init_method_normal, +) + + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = extended_attention_mask < 0.5 + + return extended_attention_mask + + +def bert_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class BertLMHead(MegatronModule): + """Masked LM head for Bert + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: whether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, hidden_size, init_method, layernorm_epsilon, parallel_output): + + super(BertLMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + parallel_state.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + self.parallel_output = parallel_output + + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.gelu = torch.nn.functional.gelu + if args.openai_gelu: + self.gelu = openai_gelu + elif args.onnx_safe: + self.gelu = erf_gelu + + def forward(self, hidden_states, word_embeddings_weight): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias) + return output + + +def post_language_model_processing( + lm_output, pooled_output, lm_head, binary_head, lm_labels, logit_weights, fp16_lm_cross_entropy +): + # Output. 
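+    # With parallel_output=True the logits stay split along the vocabulary
+    # dimension across tensor-parallel ranks; vocab_parallel_cross_entropy then
+    # computes the loss on the partitioned logits without gathering the full vocab.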
+ lm_logits = lm_head(lm_output, logit_weights) + + binary_logits = None + if binary_head is not None: + binary_logits = binary_head(pooled_output) + + if lm_labels is None: + return lm_logits, binary_logits + else: + if fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels) + return lm_loss, binary_logits + + +class BertModel(MegatronModule): + """Bert Language model.""" + + def __init__( + self, num_tokentypes=2, add_binary_head=True, parallel_output=True, pre_process=True, post_process=True + ): + super(BertModel, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.add_binary_head = add_binary_head + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=self.add_binary_head, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + self.initialize_word_embeddings(init_method_normal) + if self.post_process: + self.lm_head = BertLMHead( + self.word_embeddings_weight().size(0), + args.hidden_size, + init_method, + args.layernorm_epsilon, + parallel_output, + ) + self._lm_head_key = 'lm_head' + self.binary_head = None + if self.add_binary_head: + self.binary_head = get_linear_layer(args.hidden_size, 2, init_method) + self._binary_head_key = 'binary_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, bert_model_input, attention_mask, tokentype_ids=None, lm_labels=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = bert_model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model(input_ids, position_ids, extended_attention_mask, tokentype_ids=tokentype_ids) + + if self.post_process and self.add_binary_head: + lm_output, pooled_output = lm_output + else: + pooled_output = None + + if self.post_process: + return post_language_model_processing( + lm_output, + pooled_output, + self.lm_head, + self.binary_head, + lm_labels, + self.word_embeddings_weight(), + self.fp16_lm_cross_entropy, + ) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + if self.post_process: + state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + if self.post_process and self.add_binary_head: + state_dict_[self._binary_head_key] = self.binary_head.state_dict(destination, prefix, keep_vars) + # Save word_embeddings. 
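+        # Under pipeline parallelism the last stage keeps its own copy of the word
+        # embeddings so the output head can tie weights with them; that copy only
+        # exists when this rank is not also the first (embedding) stage.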
+ if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict( + destination, prefix, keep_vars + ) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict) + if self.post_process: + self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict) + if self.post_process and self.add_binary_head: + self.binary_head.load_state_dict(state_dict[self._binary_head_key], strict=strict) + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py new file mode 100644 index 000000000..a4680ec56 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT-2 model.""" + +import torch +from apex.transformer import tensor_parallel +from apex.transformer.enums import AttnMaskType + +from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits +from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal, scaled_init_method_normal + + +def post_language_model_processing( + lm_output, + labels, + logit_weights, + get_key_value, + parallel_output, + forward_method_parallel_output, + fp16_lm_cross_entropy, +): + if get_key_value: + lm_output, presents = lm_output + + # Output. 
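+    # forward_method_parallel_output, when not None, overrides the model-level
+    # parallel_output setting for this call (e.g. inference can request gathered
+    # logits while training keeps them split across tensor-parallel ranks).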
+ if forward_method_parallel_output is not None: + parallel_output = forward_method_parallel_output + output = parallel_lm_logits(lm_output, logit_weights, parallel_output) + + if get_key_value: + output = [output, presents] + + if labels is None: + return output + else: + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) + return loss + + +class GPTModel(MegatronModule): + """GPT-2 Language model.""" + + def __init__( + self, + vocab_size, + hidden_size, + max_position_embeddings, + num_layers, + num_attention_heads, + ffn_hidden_size, + apply_query_key_layer_scaling=True, + kv_channels=None, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + init_method_std=0.02, + fp16_lm_cross_entropy=False, + use_cpu_initialization=False, + hidden_dropout=0.1, + fused_fp16=False, + fused_bf16=False, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + layernorm_epsilon=1e-5, + bias_gelu_fusion=True, + openai_gelu=False, + onnx_safe=False, + ): + super(GPTModel, self).__init__() + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + + assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time." + + if kv_channels is None: + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + self.language_model, self._language_model_key = get_language_model( + vocab_size=vocab_size, + hidden_size=hidden_size, + hidden_dropout=hidden_dropout, + num_tokentypes=num_tokentypes, + max_position_embeddings=max_position_embeddings, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.causal, + init_method=init_method_normal(init_method_std), + scaled_init_method=scaled_init_method_normal(init_method_std, num_layers), + pre_process=self.pre_process, + post_process=self.post_process, + init_method_std=init_method_std, + use_cpu_initialization=use_cpu_initialization, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + layernorm_epsilon=layernorm_epsilon, + bias_gelu_fusion=bias_gelu_fusion, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + ) + + self.initialize_word_embeddings( + init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size + ) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward( + self, + input_ids, + position_ids, + attention_mask, + labels=None, + tokentype_ids=None, + layer_past=None, + get_key_value=False, + forward_method_parallel_output=None, + ): + + lm_output = self.language_model( + input_ids, position_ids, attention_mask, layer_past=layer_past, get_key_value=get_key_value + ) + + if self.post_process: + return 
post_language_model_processing( + lm_output, + labels, + self.word_embeddings_weight(), + get_key_value, + self.parallel_output, + forward_method_parallel_output, + self.fp16_lm_cross_entropy, + ) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + # Save word_embeddings. + if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict( + destination, prefix, keep_vars + ) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict) + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/t5_model.py b/nemo/collections/nlp/models/language_modeling/megatron/t5_model.py new file mode 100644 index 000000000..e06c2649d --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/t5_model.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5 model.""" + +import torch +from apex.transformer import tensor_parallel +from apex.transformer.enums import AttnMaskType +from megatron import get_args + +from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits +from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.transformer import LayerNorm +from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal, scaled_init_method_normal + + +def t5_extended_attention_mask(attention_mask_list): + def attn_mask_postprocess(attn_mask): + # [b, 1, s, s] + extended_attention_mask = attn_mask.unsqueeze(1) + return extended_attention_mask + + return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] + + +def t5_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class T5LMHead(MegatronModule): + """Masked LM head for T5 + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: wether output logits being distributed or not. 
+ """ + + def __init__(self, mpu_vocab_size, parallel_output): + super(T5LMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + self.bias.model_parallel = True + self.bias.partition_dim = 0 + self.bias.stride = 1 + self.parallel_output = parallel_output + + def forward(self, hidden_states, word_embeddings_weight): + output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias) + return output + + +class T5Model(MegatronModule): + """T5 Language model.""" + + def __init__(self, num_tokentypes=0, parallel_output=True): + super(T5Model, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + add_decoder=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + ) + + self.lm_head = T5LMHead(self.language_model.embedding.word_embeddings.weight.size(0), parallel_output) + self._lm_head_key = 'lm_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward( + self, + encoder_input_ids, + decoder_input_ids, + encoder_attn_mask, + decoder_attn_mask, + encoder_decoder_attn_mask, + tokentype_ids=None, + lm_labels=None, + enc_hidden_states=None, + ): + + # Converting the attention masks to proper parameter settings + encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask( + [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask] + ) + + encoder_position_ids = t5_position_ids(encoder_input_ids) + decoder_position_ids = t5_position_ids(decoder_input_ids) + + lm_output = self.language_model( + encoder_input_ids, + encoder_position_ids, + encoder_attn_mask, + decoder_input_ids, + decoder_position_ids, + decoder_attn_mask, + encoder_decoder_attn_mask, + tokentype_ids=tokentype_ids, + enc_hidden_states=enc_hidden_states, + ) + + decoder_output, encoder_output = lm_output + + # Output. 
+ lm_logits = self.lm_head(decoder_output, self.language_model.embedding.word_embeddings.weight) + + if lm_labels is None: + return lm_logits, encoder_output + else: + if self.fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels) + return lm_loss, encoder_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict) + self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py new file mode 100644 index 000000000..59aafe0da --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -0,0 +1,406 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
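MegatronGPTModel below pulls its architecture, parallelism, and tokenizer settings from an OmegaConf DictConfig. A minimal sketch of that structure, using only keys the constructor reads; the concrete values and file paths are placeholders, and a real pretraining run additionally needs the data, optimizer, and trainer sections of the full config:

    from omegaconf import OmegaConf

    gpt_cfg = OmegaConf.create({
        "tensor_model_parallel_size": 1,
        "seed": 1234,
        "fused_bf16": False,
        "make_vocab_size_divisible_by": 128,
        "hidden_size": 768,
        "ffn_hidden_size": 3072,
        "num_layers": 12,
        "num_attention_heads": 12,
        "max_position_embeddings": 1024,
        "tokenizer": {
            "library": "megatron",                 # placeholder tokenizer settings
            "type": "GPT2BPETokenizer",
            "model": None,
            "vocab_file": "/path/to/gpt2-vocab.json",
            "merge_file": "/path/to/gpt2-merges.txt",
        },
    })
    # model = MegatronGPTModel(cfg=gpt_cfg, trainer=trainer)  # trainer: a pytorch_lightning.Trainer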
+ +import re +from typing import Any, Dict, Optional + +import torch +from apex.transformer import parallel_state, tensor_parallel +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( + MegatronPretrainingRandomSampler, + MegatronPretrainingSampler, +) +from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets +from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel +from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32 +from nemo.collections.nlp.modules.common.megatron.megatron_init import ( + initialize_model_parallel_for_nemo, + set_jit_fusion_options, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.collections.nlp.parts.nlp_overrides import NLPNativeMixedPrecisionPlugin +from nemo.utils import AppState, logging + + +class MegatronGPTModel(NLPModel): + """ + Megatron GPT pretraining + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + self.cfg = cfg + + if self.cfg.get('use_cpu_initialization', False) is False: + torch.cuda.set_device(trainer.local_rank) + + # buffer used during train_step for logging average loss over gradient accumulation steps + self._reduced_loss_buffer = [] + + initialize_model_parallel_for_nemo( + world_size=trainer.world_size, + global_rank=trainer.global_rank, + local_rank=trainer.local_rank, + tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1), + seed=self.cfg.get('seed', 1234), + ) + + if not self.cfg.get('fused_bf16'): + set_jit_fusion_options() + + self.tokenizer = get_nmt_tokenizer( + library=self.cfg.tokenizer.library, + model_name=self.cfg.tokenizer.type, + tokenizer_model=self.register_artifact("tokenizer_model", self.cfg.tokenizer.model), + vocab_file=self.register_artifact("vocab_file", self.cfg.tokenizer.vocab_file), + merges_file=self.register_artifact("merges_file", self.cfg.tokenizer.merge_file), + ) + + vocab_size = self.tokenizer.vocab_size + + padded_vocab_size = self._vocab_size_with_padding( + orig_vocab_size=vocab_size, + make_vocab_size_divisible_by=cfg.get('make_vocab_size_divisible_by', 128), + tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1), + ) + + self.model = GPTModel( + vocab_size=padded_vocab_size, + hidden_size=cfg.hidden_size, + max_position_embeddings=cfg.max_position_embeddings, + num_layers=cfg.num_layers, + num_attention_heads=cfg.num_attention_heads, + apply_query_key_layer_scaling=cfg.get('apply_query_key_layer_scaling', True), + kv_channels=cfg.get('kv_channels', None), + ffn_hidden_size=cfg.ffn_hidden_size, + num_tokentypes=0, + parallel_output=True, + pre_process=cfg.get('pre_process', True), + post_process=cfg.get('post_process', True), + init_method_std=cfg.get('init_method_std', 0.02), + fp16_lm_cross_entropy=cfg.get('fp16_lm_cross_entropy', False), + use_cpu_initialization=cfg.get('use_cpu_initialization', False), + hidden_dropout=cfg.get('hidden_dropout', 0.1), + fused_fp16=cfg.get('fused_fp16', False), + fused_bf16=cfg.get('fused_bf16', False), + fp32_residual_connection=cfg.get('fp32_residual_connection', False), + 
activations_checkpoint_method=cfg.get('activations_checkpoint_method', None), + activations_checkpoint_num_layers=cfg.get('activations_checkpoint_num_layers', 1), + layernorm_epsilon=cfg.get('layernorm_epsilon', 1e-5), + onnx_safe=cfg.get('onnx_safe', False), + ) + + def forward(self, tokens, position_ids, attention_mask, labels): + output_tensor = self.model(tokens, position_ids, attention_mask, labels=labels) + return output_tensor + + def training_step(self, batch, batch_idx): + tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch) + output_tensor = self(tokens, position_ids, attention_mask, labels) + loss = self.loss_func(loss_mask, output_tensor) + reduced_loss = average_losses_across_data_parallel_group([loss]) + + # cache reduced loss while accumulating gradients + self._reduced_loss_buffer.append(reduced_loss[0]) + + if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: + # Reduced loss for logging. + average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer) + self.log('reduced_train_loss', average_reduced_loss, prog_bar=True) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr) + self.log('global_step', self.trainer.global_step, prog_bar=True) + self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step), prog_bar=True) + self._reduced_loss_buffer = [] + return loss + + def validation_step(self, batch, batch_idx): + tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch) + output_tensor = self(tokens, position_ids, attention_mask, labels) + loss = self.loss_func(loss_mask, output_tensor) + # TODO: add text generation + # take the first k tokens and then generate text - compare with ground truth (won't look similar in general) + """ + k = num_context + n = max_generate_length + context_tokens = tokens[0:k] + while k < n: + output_tensor = self(context_tokens) + next_token = sample(output_tensor) + context_tokens.append(next_token) + k += 1 + """ + return loss + + def validation_epoch_end(self, outputs): + averaged_loss = average_losses_across_data_parallel_group(outputs) + self.log('val_loss', averaged_loss[0], prog_bar=True) + self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step)) + + def test_step(self, batch, batch_idx): + return self.validation_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + averaged_loss = average_losses_across_data_parallel_group(outputs) + logging.info(f'test_loss: {averaged_loss[0]}') + + def loss_func(self, loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + # TODO: add nemo version here + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + return loss + + def process_batch(self, batch): + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + data = batch + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
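# Left-to-right LM setup: tokens and labels above are the same sequence shifted
# by one position (e.g. tokens = [t0, t1, t2] is trained to predict
# labels = [t1, t2, t3]). The helper below builds the causal attention mask,
# the loss mask and the position ids from the token ids and the EOD settings
# taken from cfg.data.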
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + self.tokenizer.eos_id, + self.cfg.data.get('reset_position_ids', False), + self.cfg.data.get('reset_attention_mask', False), + self.cfg.data.get('eod_mask_loss', False), + ) + + return tokens, labels, loss_mask, attention_mask, position_ids + + def build_train_valid_test_datasets(self): + logging.info('Building GPT datasets.') + global_batch_size = self.trainer.world_size * self.cfg.micro_batch_size / self.cfg.tensor_model_parallel_size + eval_iters = (self.trainer.max_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches + test_iters = self.trainer.limit_test_batches + train_valid_test_num_samples = [ + self.trainer.max_steps * global_batch_size, + eval_iters * global_batch_size, + test_iters * global_batch_size, + ] + self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( + cfg=self.cfg, + trainer=self.trainer, + data_prefix=self.cfg.data.data_prefix, + data_impl=self.cfg.data.data_impl, + splits_string=self.cfg.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=self.cfg.data.seq_length, + seed=self.cfg.seed, + skip_warmup=self.cfg.data.get('skip_warmup', True), + ) + if self._train_ds is not None: + logging.info(f'Length of train dataset: {len(self._train_ds)}') + if self._validation_ds is not None: + logging.info(f'Length of val dataset: {len(self._validation_ds)}') + if self._test_ds is not None: + logging.info(f'Length of test dataset: {len(self._test_ds)}') + logging.info(f'Finished building GPT datasets.') + return self._train_ds, self._validation_ds, self._test_ds + + def build_pretraining_data_loader(self, dataset, consumed_samples): + """Buld dataloader given an input dataset.""" + + if dataset is None: + return None + + # Megatron sampler + if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None: + if self.cfg.data.dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self.cfg.micro_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + elif self.cfg.data.dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self.cfg.micro_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + else: + raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"') + else: + raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"') + + # Torch dataloader. + return torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, num_workers=self.cfg.data.num_workers, pin_memory=True, + ) + + def setup(self, stage=None): + if stage == 'predict': + return + # TODO: consider adding a ModelPT guard to check if model is being restored. 
+ # allowing restored models to optionally setup datasets + self.build_train_valid_test_datasets() + self.setup_training_data(self.cfg.data) + self.setup_validation_data(self.cfg.data) + self.setup_test_data(self.cfg.data) + + def setup_training_data(self, cfg): + if hasattr(self, '_train_ds'): + resume_checkpoint_path = self.trainer.checkpoint_connector.resume_checkpoint_path + if resume_checkpoint_path: + consumed_samples = int( + float(re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", resume_checkpoint_path)[0]) + ) + else: + consumed_samples = 0 + logging.info( + f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' + ) + self._train_dl = self.build_pretraining_data_loader(self._train_ds, consumed_samples) + + def setup_validation_data(self, cfg): + if hasattr(self, '_validation_ds'): + consumed_samples = 0 + logging.info( + f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}' + ) + self._validation_dl = self.build_pretraining_data_loader(self._validation_ds, consumed_samples) + + def setup_test_data(self, cfg): + if hasattr(self, '_test_ds'): + consumed_samples = 0 + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + ) + self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples) + + def compute_consumed_samples(self, global_step): + app_state = AppState() + consumed_samples = ( + global_step + * app_state.data_parallel_size + * self.cfg.micro_batch_size + * self.trainer.accumulate_grad_batches + ) + return int(consumed_samples) + + def on_before_optimizer_step(self, optimizer, optimizer_idx): + """PTL hook that is called after unscaling gradients when using native amp. + We use gradient clipping implementation from megatron-lm. + """ + clip_val = self.trainer.gradient_clip_val + if clip_val is None: + return + + clip_val = float(clip_val) + if clip_val <= 0: + return + + if isinstance(self.trainer.accelerator_connector.precision_plugin, NLPNativeMixedPrecisionPlugin): + parameters = self.model.parameters() + clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val) + else: + return + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: + request = batch + response = self.complete(request) + return response + + def complete(self, request: Dict): + """ + Autoregressively invokes language model in the inference mode + Args: + request: Dictionary with the following fields + * prompt: a string which text the model should complete. + * tokens_to_generate: how many tokens to generate while doing prompt completion. + * stop_after_sentence: (default True) whether to stop generation once sentence end is reached. 
+ Returns: + response: A python dictionary with the following fields + * prompt: original text of the prompt + * tokenized_prompt: list of (str) tokens from prompt + * completion: a python dictionary with the following subfields: + * tokens: a list of triples (token, token_id, log_prob) comprising completion + * stop reason: either 'eos', 'sentence_end' or 'limit' indicating why generation stopped + * text: completion text (as a single string) + + """ + response = {} + self.freeze() + logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1) + response['tokenized_prompt'] = request['tokenized_prompt'] + tokens = request['tokens'] + # naive greedy slow loop + # TODO: add option for BeamSearchDecoder + response['prompt'] = request['prompt'] + response['completion'] = {} + response['completion']['stop reason'] = 'limit' + for i in range(request.get("tokens_to_generate", 64)): + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=self.tokenizer.eos_id, + reset_position_ids=self.cfg.get('reset_position_ids', False), + reset_attention_mask=self.cfg.get('reset_attention_mask', False), + eod_mask_loss=self.cfg.get('eod_mask_loss', False), + ) + # No labels during inference. Still need masks to not attend to the right + output_tensor = self(tokens, position_ids, attention_mask, labels=None) + output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor) + log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1) + reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id + tokens = torch.cat([torch.squeeze(tokens), token_ids[:, -1]]) + response['completion']["tokens"] = list( + zip(self.tokenizer.ids_to_tokens(tokens), tokens.tolist(), log_probs.tolist()[0]) + ) + completion_text = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"]) + if reached_eos: # Will it actually ever reach that? + response['completion']['stop reason'] = 'eos' + break + elif request.get("stop_after_sentence", True) and completion_text.endswith(('.', '!', '?')): + response['completion']['stop reason'] = 'sentence_end' + break + tokens = torch.unsqueeze(tokens, 0) + response['completion']["text"] = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"]) + self.unfreeze() + return response + + def list_available_models(self): + return None + + def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + multiple = make_vocab_size_divisible_by * tensor_model_parallel_size + while (after % multiple) != 0: + after += 1 + logging.info( + f'Padded vocab_size: {after}, original vocab_size: {orig_vocab_size}, dummy tokens: {after - orig_vocab_size}.' 
+ ) + return after diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index e6d920c12..9b567a214 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -15,32 +15,29 @@ import hashlib import json import os -import shutil -import tarfile -import tempfile -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional -import torch -from megatron import mpu -from megatron.checkpointing import get_checkpoint_version, set_checkpoint_version from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.migration import pl_legacy_patch from transformers import TRANSFORMERS_CACHE from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.modules import BertModule, MegatronBertEncoder from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import VOCAB_FILE_NAME +from nemo.collections.nlp.modules.common.megatron.megatron_bert import ( + get_megatron_checkpoint_version, + set_megatron_checkpoint_version, +) from nemo.collections.nlp.modules.common.megatron.megatron_encoder import MegatronEncoderModule -from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer -from nemo.collections.nlp.parts.nlp_overrides import NLPCheckpointConnector +from nemo.collections.nlp.parts.nlp_overrides import NLPCheckpointConnector, NLPSaveRestoreConnector from nemo.core.classes import ModelPT from nemo.core.classes.exportable import Exportable -from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.utils import AppState, logging -from nemo.utils.get_rank import is_global_rank_zero __all__ = ['NLPModel'] @@ -55,6 +52,8 @@ class NLPModel(ModelPT, Exportable): def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg, trainer) + # handles model parallel save and restore logic + self._save_restore_connector = NLPSaveRestoreConnector() self.set_world_size(trainer) def register_artifact( @@ -180,274 +179,31 @@ class NLPModel(ModelPT, Exportable): f'Registering tokenizer vocab for {self.tokenizer} is not yet supported. Please override this method if needed.' ) - def _clip_gradients(self, optimizer, clip_val=None): - """ Override of PTL Gradient Clipping. - Enables model parallel gradient clipping from Megatron-LM. - - Args: - optimizer ([type]): [description] - clip_val ([type], optional): [description]. Defaults to None. 
- """ - app_state = AppState() - - # get clip_val from trainer if None is provided - if clip_val is None: - clip_val = float(self._trainer.gradient_clip_val) - - if app_state.model_parallel_size is not None: - model = self._trainer.get_model() - parameters = model.parameters() - if mpu.model_parallel_is_initialized(): - mpu.grads.clip_grad_norm(parameters=parameters, max_norm=clip_val) - else: - raise ValueError('Model parallel groups must be intialized to use model parallel gradient clipping.') - - else: - return Accelerator._clip_gradients(self, optimizer, clip_val) - - def setup(self, stage: str) -> None: - """ PTL hook that is called on all DDP processes. """ - - if stage == 'fit': - - # adds self.bert_model config to .nemo file - if hasattr(self, 'bert_model') and self.bert_model is not None: - self.register_bert_model() - - app_state = AppState() - - if app_state.model_parallel_size is not None: - - self._trainer.checkpoint_connector = NLPCheckpointConnector(self._trainer) - - # # Configure checkpointing for model parallel - # if app_state.create_checkpoint_callback: - # # global rank 0 is configured by exp_manager - # if not is_global_rank_zero() and app_state.data_parallel_rank == 0: - # configure_checkpointing( - # self._trainer, - # app_state.log_dir, - # app_state.checkpoint_name, - # app_state.checkpoint_callback_params, - # ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ LightningModule hook that's used to save things in addition to model weights. """ if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder): - checkpoint['checkpoint_version'] = get_checkpoint_version() + checkpoint['checkpoint_version'] = get_megatron_checkpoint_version() return None def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ LightningModule hook that's used to restore things saved with on_save_checkpoint.""" if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder): - if get_checkpoint_version(): + if get_megatron_checkpoint_version(): assert ( - checkpoint['checkpoint_version'] == get_checkpoint_version() - ), 'checkpoint version found on_load_checkpoint different than get_checkpoint_version' + checkpoint['checkpoint_version'] == get_megatron_checkpoint_version() + ), 'checkpoint version found on_load_checkpoint different than get_megatron_checkpoint_version' else: - set_checkpoint_version(checkpoint['checkpoint_version']) + set_megatron_checkpoint_version(checkpoint['checkpoint_version']) logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}") return None - # no rank check as model parallel models need to be saved on data parallel rank 0 - def save_to(self, save_path: str): - """ - Saves model instance (weights and configuration) into .nemo file - You can use "restore_from" method to fully restore instance from .nemo file. - - .nemo file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. 
You can deserialize this into cfg argument for model's constructor - model_weights.ckpt - model checkpoint - - Args: - save_path: Path to .nemo file where model instance should be saved - """ - save_path = os.path.abspath(save_path) - app_state = AppState() - if app_state.model_parallel_size is not None: - self._default_save_to(save_path) - else: - # super.save_to only runs on global rank 0 - return super().save_to(save_path) - - def _default_save_to(self, save_path: str): - app_state = AppState() - if app_state.model_parallel_size is not None: - # each model parallel rank creates a .nemo file - # after all .nemo files are created, each rank - # will add their checkpoint to global rank 0 - - base_dir = os.path.dirname(save_path) # use the directory to merge mp_rank .nemo files into one - - # update save_path based on model parallel_rank - base_path = os.path.splitext(save_path)[0] # everything except the extension - - mp_save_path = f'{base_path}_mp_rank_{app_state.model_parallel_rank:02d}.nemo' - - if app_state.data_parallel_rank == 0: - super()._default_save_to(mp_save_path) - - # barrier so that all processes have finished writing their weights before creating .nemo file - torch.distributed.barrier() - - if is_global_rank_zero(): - # extract all tar files - for mp_rank in range(app_state.model_parallel_size): - mp_tar_path = f'{base_path}_mp_rank_{mp_rank:02d}.nemo' - mp_tar = tarfile.open(mp_tar_path, 'r:gz') - mp_tar.extractall(path=os.path.join(base_dir, f'mp_rank_{mp_rank:02d}')) - mp_tar.close() - os.remove(mp_tar_path) - - # move rank 0 .nemo extract to base_path - shutil.move(os.path.join(base_dir, 'mp_rank_00'), base_path) - - # move mp_rank_00 checkpoint to mp_rank_00 directory inside base_path - os.mkdir(os.path.join(base_path, 'mp_rank_00')) - shutil.move(os.path.join(base_path, 'model_weights.ckpt'), os.path.join(base_path, 'mp_rank_00')) - - # move other mp_rank checkpoints from base_dir to base_path - for mp_rank in range(1, app_state.model_parallel_size): - os.mkdir(os.path.join(base_path, f'mp_rank_{mp_rank:02d}')) - shutil.move( - os.path.join(base_dir, f'mp_rank_{mp_rank:02d}', 'model_weights.ckpt'), - os.path.join(base_path, f'mp_rank_{mp_rank:02d}'), - ) - # clean up leftover directory - shutil.rmtree(os.path.join(base_dir, f'mp_rank_{mp_rank:02d}')) - - # create tar file from base_path - self._make_nemo_file_from_folder(save_path, base_path) - - # clean up base_path - shutil.rmtree(base_path) - - elif is_global_rank_zero(): - return super()._default_save_to(save_path) - else: - return - - @classmethod - def restore_from( - cls, - restore_path: str, - override_config_path: Optional[Union[OmegaConf, str]] = None, - map_location: Optional[torch.device] = None, - strict: bool = True, - return_config: bool = False, - trainer: Trainer = None, - save_restore_connector: SaveRestoreConnector = None, - ): - """ - Restores model instance (weights and configuration) from .nemo file. - - Args: - restore_path: path to .nemo file from which model should be instantiated - override_config_path: path to a yaml config that will override the internal - config file or an OmegaConf / DictConfig object representing the model config. - map_location: Optional torch.device() to map the instantiated model to a device. - By default (None), it will select a GPU if available, falling back to CPU otherwise. - strict: Passed to load_state_dict. Set to True by default. 
- return_config: If set to true, will return just the underlying config of the restored - model as an OmegaConf DictConfig object without instantiating the model. - trainer: PyTorch Lightning trainer. Must be passed in order to use model parallel .nemo - - Example: - ``` - model = nemo.collections.nlp.models.TokenClassificationModel.restore_from('token_classification.nemo') - assert isinstance(model, nemo.collections.nlp.models.TokenClassificationModel) - ``` - - Returns: - An instance of type cls or its underlying config (if return_config is set). - """ - if save_restore_connector is None: - save_restore_connector = SaveRestoreConnector() - - if not os.path.exists(restore_path): - raise FileNotFoundError(f"Can't find {restore_path}") - - app_state = AppState() - app_state.model_restore_path = os.path.abspath(os.path.expanduser(restore_path)) - - # detect if we have a model parallel .nemo file - with tempfile.TemporaryDirectory() as tmpdir: - cwd = os.getcwd() - os.chdir(tmpdir) - # detect if model parallel from tarfile - tar = tarfile.open(app_state.model_restore_path, "r:gz") - names = tar.getnames() - mp_ranks = [] - for name in names: - if 'mp_rank' in name: - mp_ranks.append(name) - if mp_ranks: - app_state.model_parallel_size = len(mp_ranks) // 2 # directory and file are included in getnames() - - # get checkpoint version - checkpoint_version_member = None - for member in tar.getmembers(): - if 'megatron_checkpoint_version.json' in member.name: - checkpoint_version_member = member - tar.extract(checkpoint_version_member, tmpdir) - with open(checkpoint_version_member.name, 'r') as f: - checkpoint_version = json.load(f).get('checkpoint_version', None) - logging.info( - ( - f'Detected model parallel .nemo file: {restore_path}. ' - f'Assuming megatron model parallelism with ' - f'model_parallel_size: {app_state.model_parallel_size} ' - f'and checkpoint version: {checkpoint_version}' - ) - ) - tar.close() - os.chdir(cwd) - - if app_state.model_parallel_size is not None: - if not isinstance(trainer, Trainer): - raise ValueError("trainer must be a PyTorch Lightning Trainer to restore model parallel .nemo files.") - - if checkpoint_version is None: - raise ValueError( - "Restoring from megatron model parallel .nemo but could not find megatron checkpoint version." - ) - else: - logging.info(f"Setting megatron checkpoint version: {checkpoint_version}") - set_checkpoint_version(checkpoint_version) - - app_state.world_size = trainer.num_gpus * trainer.num_nodes - - if trainer.local_rank is not None: - app_state.local_rank = trainer.local_rank - else: - raise ValueError("trainer.local_rank is None. 
local_rank needed to restore model parallel models.") - - model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size) - app_state.model_parallel_rank = model_parallel_rank - - cls.update_save_restore_connector(save_restore_connector) - restored_model = cls._save_restore_connector.restore_from( - cls, app_state.model_restore_path, override_config_path, map_location, strict, return_config - ) - restored_model.set_trainer(trainer) - return restored_model - else: - return super().restore_from( - app_state.model_restore_path, - override_config_path, - map_location, - strict, - return_config, - save_restore_connector=save_restore_connector, - ) - @rank_zero_only def register_megatron_checkpoint_version(self): """ Adds checkpoint version to .nemo archive """ if self.has_megatron_encoder: - checkpoint_version = get_checkpoint_version() + checkpoint_version = get_megatron_checkpoint_version() if checkpoint_version is None: raise ValueError('Unable to get megatron checkpoint version.') else: @@ -511,3 +267,72 @@ class NLPModel(ModelPT, Exportable): if isinstance(self.encoder, MegatronEncoderModule): logging.info(f"Restoring from pretrained model parallel checkpoint: {self.encoder.checkpoint_file}") self.encoder._encoder.restore_weights(self.encoder.checkpoint_file) + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + map_location: Any = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + """ + Loads ModelPT from checkpoint, with some maintenance of restoration. + For documentation, please refer to LightningModule.load_from_checkpoin() documentation. + """ + checkpoint = None + try: + cls._set_model_restore_state(is_being_restored=True) + # TODO: replace with proper PTL API + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + if hparams_file is not None: + extension = hparams_file.split(".")[-1] + if extension.lower() == "csv": + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ("yml", "yaml"): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + + hparams["on_gpu"] = False + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + # override the hparams with values that were passed in + # TODO: can we do this without overriding? + config_kwargs = kwargs.copy() + if 'trainer' in config_kwargs: + config_kwargs.pop('trainer') + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs) + + model = cls._load_model_state(checkpoint, strict=strict, **kwargs) + checkpoint = model + + finally: + cls._set_model_restore_state(is_being_restored=False) + return checkpoint + + def save_to(self, save_path: str): + app_state = AppState() + # Add NeMo rank check as well + if app_state.model_parallel_size is not None: + if app_state.model_parallel_size > 1: + if not isinstance(self._save_restore_connector, NLPSaveRestoreConnector): + logging.warning( + f"Using {self._save_restore_connector.__class__} to save a model parallel model. Overriding with NLPSaveRestoreConnector. Make sure to subclass NLPSaveRestoreConnector." 
+ ) + self._save_restore_connector = NLPSaveRestoreConnector() + save_path = os.path.abspath(os.path.expanduser(save_path)) + self._save_restore_connector.save_to(self, save_path) + else: + super(NLPModel, self).save_to(save_path=save_path) diff --git a/nemo/collections/nlp/modules/common/bert_module.py b/nemo/collections/nlp/modules/common/bert_module.py index fbac6a858..735c68f7b 100644 --- a/nemo/collections/nlp/modules/common/bert_module.py +++ b/nemo/collections/nlp/modules/common/bert_module.py @@ -39,16 +39,6 @@ class BertModule(NeuralModule, Exportable): def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} - @classmethod - def restore_from(cls, restore_path: str): - """Restores module/model with weights""" - pass - - @classmethod - def save_to(self, save_path: str): - """Saves module/model with weights""" - pass - def restore_weights(self, restore_path: str): """Restores module/model's weights""" logging.info(f"Restoring weights from {restore_path}") diff --git a/nemo/collections/nlp/modules/common/lm_utils.py b/nemo/collections/nlp/modules/common/lm_utils.py index 4feb5698e..ab93e8362 100644 --- a/nemo/collections/nlp/modules/common/lm_utils.py +++ b/nemo/collections/nlp/modules/common/lm_utils.py @@ -89,13 +89,15 @@ def get_lm_model( ) if "megatron" in pretrained_model_name: - model, checkpoint_file = get_megatron_lm_model( - config_dict=config_dict, - config_file=config_file, - pretrained_model_name=pretrained_model_name, - checkpoint_file=checkpoint_file, - vocab_file=vocab_file, - ) + raise ValueError('megatron-lm BERT models have been deprecated in NeMo 1.5+. Please use NeMo 1.4 for support.') + # TODO: enable megatron bert in nemo + # model, checkpoint_file = get_megatron_lm_model( + # config_dict=config_dict, + # config_file=config_file, + # pretrained_model_name=pretrained_model_name, + # checkpoint_file=checkpoint_file, + # vocab_file=vocab_file, + # ) else: model = get_huggingface_lm_model( config_dict=config_dict, config_file=config_file, pretrained_model_name=pretrained_model_name, @@ -180,13 +182,17 @@ def get_transformer( ) elif library == 'megatron': - model = get_megatron_transformer( - model_name=model_name, - pretrained=pretrained, - config_dict=config_dict, - encoder=encoder, - checkpoint_file=checkpoint_file, + raise ValueError( + f'megatron-lm bert support has been deprecated in NeMo 1.5+. Please use NeMo 1.4 for support.' ) + # TODO: enable megatron bert in nemo + # model = get_megatron_transformer( + # model_name=model_name, + # pretrained=pretrained, + # config_dict=config_dict, + # encoder=encoder, + # checkpoint_file=checkpoint_file, + # ) else: raise ValueError("Libary must be 'nemo', 'huggingface' or 'megatron'") diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py new file mode 100644 index 000000000..09f8e354f --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py @@ -0,0 +1,137 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gradient clipping.""" + +import amp_C +import torch +from apex.multi_tensor_apply import multi_tensor_applier +from apex.transformer import parallel_state +from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate +from torch._six import inf + +from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared + + +def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters whose gradients + are in fp32. + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + Returns: + Total norm of the parameters (viewed as a single vector). + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + grads = [] + grads_for_norm = [] + for param in parameters: + grad_not_none = param.grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + grad = param.grad.detach() + if grad_not_none: + # Make sure the grads are in fp32 + assert isinstance(param.grad, torch.cuda.FloatTensor) + grads.append(grad) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grads_for_norm.append(grad) + + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + total_norm = 0.0 + + # Calculate norm. + if norm_type == inf: + total_norm = max(grad.abs().max() for grad in grads_for_norm) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + torch.distributed.all_reduce( + total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group() + ) + total_norm = total_norm_cuda[0].item() + + else: + if norm_type == 2.0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + # Use apex's multi-tensor applier for efficiency reasons. + # Multi-tensor applier takes a function and a list of list + # and performs the operation on that list all in one kernel. + grad_norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm + ) + # Since we will be summing across data parallel groups, + # we need the pow(norm-type). + total_norm = grad_norm ** norm_type + + else: + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type + + # Sum across all model-parallel GPUs. 
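# Each tensor-model-parallel rank only holds a shard of the parameters, so the
# per-rank partial norms are summed over the model-parallel group before the
# final 1/norm_type root is taken below.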
+ torch.distributed.all_reduce( + total_norm, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group() + ) + total_norm = total_norm.item() ** (1.0 / norm_type) + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + + return total_norm + + +def count_zeros_fp32(parameters): + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = 0.0 + for param in parameters: + grad_not_none = param.grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce( + total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group() + ) + total_num_zeros = total_num_zeros.item() + + return total_num_zeros diff --git a/nemo/collections/nlp/modules/common/megatron/enums.py b/nemo/collections/nlp/modules/common/megatron/enums.py new file mode 100644 index 000000000..cd3605a83 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/enums.py @@ -0,0 +1,30 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + + +class LayerType(enum.Enum): + encoder = 1 + decoder = 2 + + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py new file mode 100644 index 000000000..cc21ac987 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
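For reference on clip_grad_norm_fp32 above: in the single-GPU, non-model-parallel case it behaves like the stock PyTorch utility, scaling the gradients by max_norm / (total_norm + 1e-6) whenever that ratio is below one and returning the pre-clipping total norm. A hedged sketch of that equivalence (the wrapper name clip_single_gpu is hypothetical):

import torch

def clip_single_gpu(parameters, max_norm: float) -> torch.Tensor:
    # torch.nn.utils.clip_grad_norm_ also returns the total (pre-clipping) L2
    # gradient norm, matching the return value of clip_grad_norm_fp32.
    return torch.nn.utils.clip_grad_norm_(parameters, max_norm)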
+ + +import torch + +from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +@torch.jit.script +def bias_dropout_add_fused_train_( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: + # type: (Tensor, Tensor, Tensor, float) -> Tensor + return bias_dropout_add(x, bias, residual, prob, True) + + +class BiasDropoutAddFusedTrain(AutocastModuleWrapper): + def __init__(self, fp16=False, bf16=False): + super(BiasDropoutAddFusedTrain, self).__init__(fp16, bf16) + + self.func = bias_dropout_add_fused_train_ + + def forward(self, x, bias, residual, prob): + return self.autocast_forward(x, bias, residual, prob) + + +@torch.jit.script +def bias_dropout_add_fused_inference_( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: + # type: (Tensor, Tensor, Tensor, float) -> Tensor + return bias_dropout_add(x, bias, residual, prob, False) + + +class BiasDropoutAddFusedInference(AutocastModuleWrapper): + def __init__(self, fp16=False, bf16=False): + super(BiasDropoutAddFusedInference, self).__init__(fp16, bf16) + + self.func = bias_dropout_add_fused_inference_ + + def forward(self, x, bias, residual, prob): + return self.autocast_forward(x, bias, residual, prob) diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_gelu.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_gelu.py new file mode 100644 index 000000000..25b84cefa --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_gelu.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g + + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + + +class FusedBiasGeLU(AutocastModuleWrapper): + def __init__(self, fp16=False, bf16=False): + super(FusedBiasGeLU, self).__init__(fp16, bf16) + + self.func = GeLUFunction + + def forward(self, input, bias): + return self.autocast_forward(input, bias) diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/__init__.py b/nemo/collections/nlp/modules/common/megatron/fused_kernels/__init__.py new file mode 100644 index 000000000..ccce2d0b8 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/__init__.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(): + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3',], + extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math'] + + extra_cuda_flags + + cc_flag, + verbose=True, + ) + + # ============== + # Fused softmax. + # ============== + + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda', + ] + + # Upper triangular softmax. 
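# The upper-triangular variant fuses scaling, implicit causal masking and
# softmax for autoregressive self-attention; the "masked" variant built next
# takes an explicit mask tensor instead.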
+ sources = [ + srcpath / 'scaled_upper_triang_masked_softmax.cpp', + srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu', + ] + scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_upper_triang_masked_softmax_cuda", sources, extra_cuda_flags + ) + + # Masked softmax. + sources = [srcpath / 'scaled_masked_softmax.cpp', srcpath / 'scaled_masked_softmax_cuda.cu'] + scaled_masked_softmax_cuda = _cpp_extention_load_helper("scaled_masked_softmax_cuda", sources, extra_cuda_flags) + + # ================================= + # Mixed precision fused layer norm. + # ================================= + + extra_cuda_flags = ['-maxrregcount=50'] + sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] + fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( + "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + ) + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/compat.h b/nemo/collections/nlp/modules/common/megatron/fused_kernels/compat.h new file mode 100644 index 000000000..92e7eb772 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/compat.h @@ -0,0 +1,31 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/layer_norm_cuda.cpp b/nemo/collections/nlp/modules/common/megatron/fused_kernels/layer_norm_cuda.cpp new file mode 100644 index 000000000..8f28e7b4a --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/layer_norm_cuda.cpp @@ -0,0 +1,201 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include +#include +#include "compat.h" + +namespace { + +void compute_n1_n2( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert( input.sizes()[i+idiff] == normalized_shape[i] ); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args( + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta + ) +{ + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2 + ) +{ + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input,normalized_shape,n1,n2); +} + + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + int& n1, + int& n2 + ) +{ + check_args(input,normalized_shape,n1,n2); + check_args(normalized_shape,gamma,beta); +} +} + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = at::empty_like( + input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = at::empty( + {n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon); + + return {output, mean, invvar}; + +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta + ); + +std::vector 
layer_norm_gradient_affine( + at::Tensor dout, + at::Tensor mean, + at::Tensor invvar, + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; + +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, + "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/nemo/collections/nlp/modules/common/megatron/fused_kernels/layer_norm_cuda_kernel.cu new file mode 100644 index 000000000..a892c069f --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -0,0 +1,829 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" + +#include +#include + +#include "type_shim.h" + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
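  // Each thread runs Welford's online update (cuWelfordOnlineSum above) over a
  // strided slice of the row; partial results are then combined with Chan's
  // merge formula (cuChanOnlineSum) via warp shuffles and shared memory.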
+ // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2*threadIdx.y]; + U sigma2B = ubuf[2*threadIdx.y+1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/U(n2), 0); + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. 
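+      // Each thread consumes 8 halves per iteration below, loading them as four
+      // __half2 pairs and converting with __half22float2 before feeding both lanes
+      // into the same Welford update; the scalar tail loop afterwards handles any
+      // remainder when n2 is not a multiple of 8*numx.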
+ for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2*threadIdx.y]; + float sigma2B = ubuf[2*threadIdx.y+1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/float(n2), 0); + } + } +} + +template U rsqrt(U v) { + return U(1) / sqrt(v); +} +template<> float rsqrt(float v) { + return rsqrtf(v); +} +template<> double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. 
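+// (Only the float specialization below is provided; these kernels presumably only
+// ever instantiate SharedMemory<float>, since the accumulation type U is float for
+// the fp32/fp16/bf16 dtypes dispatched here.)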
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template __global__ 
+void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int 
warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input) +{ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U 
term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + } +} + + + + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32,4,1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyLayerNorm<<>>( + output, + mean, + invvar, + input, + n1,n2, + U(epsilon), + gamma,beta); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * + (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? 
nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size,n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR()); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input); +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL); + ) +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax.cpp b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax.cpp new file mode 100644 index 000000000..1852aee6f --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -0,0 +1,97 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax.h b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax.h new file mode 100644 index 000000000..45e8dcea2 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax.h @@ -0,0 +1,505 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. 
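+    // Each warp owns WARP_BATCH consecutive rows of the [batches, heads, query, key]
+    // softmax. Positions whose mask byte is set are filled with -10000.0 before the
+    // max/sum reductions below, so they receive effectively zero probability.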
+ int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = 
log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax_cuda.cu new file mode 100644 index 000000000..902d36dd0 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -0,0 +1,117 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 000000000..ea283588d --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,72 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h new file mode 100644 index 000000000..6df83fc10 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,513 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. 
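+        // (for rows shorter than a full warp, warp_size shrinks to the smallest power
+        //  of two covering the row, so short softmax rows do not occupy all 32 lanes)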
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. 
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 000000000..5efc3d412 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,98 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    float scale_factor)
+{
+  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  const int attn_batches = input.size(0);
+  const int seq_len = input.size(1);
+  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+
+  // Output
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor softmax_results =
+      torch::empty({attn_batches, seq_len, seq_len}, act_options);
+
+  // Softmax Intermediate Result Ptr
+  void* input_ptr = static_cast<void*>(input.data_ptr());
+  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
+  return softmax_results;
+}
+
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads_,
+    torch::Tensor const& softmax_results_,
+    float scale_factor)  {
+
+  auto output_grads = output_grads_.contiguous();
+  auto softmax_results = softmax_results_.contiguous();
+
+  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  const int attn_batches = output_grads.size(0);
+  const int seq_len = output_grads.size(1);
+  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
+
+  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+  //Softmax Grad
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr<scalar_t>()),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
+
+  //backward pass is completely in-place
+  return output_grads;
+}
+}
+}
+}
diff --git a/nemo/collections/nlp/modules/common/megatron/fused_kernels/type_shim.h b/nemo/collections/nlp/modules/common/megatron/fused_kernels/type_shim.h
new file mode 100644
index 000000000..6437dcc7c
--- /dev/null
+++ b/nemo/collections/nlp/modules/common/megatron/fused_kernels/type_shim.h
@@ -0,0 +1,91 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/nemo/collections/nlp/modules/common/megatron/fused_softmax.py b/nemo/collections/nlp/modules/common/megatron/fused_softmax.py new file mode 100644 index 000000000..316563154 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/fused_softmax.py @@ -0,0 +1,211 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType +from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. 
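+
+    Concretely: the input is a [attn_batches, sq, sk] tensor of attention scores; each
+    score is multiplied by `scale`, entries above the diagonal (key index greater than
+    the query index) are set to -inf, and a softmax is taken over the last dimension,
+    as needed for causal (GPT-style) self-attention.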
+ """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_upper_triang_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_upper_triang_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class FusedScaledUpperTriangMaskedSoftmax(AutocastModuleWrapper): + def __init__(self, fp16=False, bf16=True): + super(FusedScaledUpperTriangMaskedSoftmax, self).__init__(fp16, bf16) + + self.func = ScaledUpperTriangMaskedSoftmax + + def forward(self, inputs, scale): + return self.autocast_forward(inputs, scale) + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + import scaled_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None + + +class FusedScaledMaskedSoftmax(AutocastModuleWrapper): + def __init__(self, fp16=False, bf16=True): + super(FusedScaledMaskedSoftmax, self).__init__(fp16, bf16) + + self.fused_func = ScaledMaskedSoftmax + + def forward(self, inputs, mask, scale): + return self.autocast_forward(inputs, mask, scale) + + +class FusedScaleMaskSoftmax(nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." 
+ self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + self.fused_scaled_uppper_triang_masked_softmax = FusedScaledUpperTriangMaskedSoftmax( + input_in_fp16, input_in_bf16 + ) + self.fused_scaled_masked_softmax = ScaledMaskedSoftmax(input_in_fp16, input_in_bf16) + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = self.fused_scaled_uppper_triang_masked_softmax(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return self.fused_scaled_masked_softmax(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import scaled_masked_softmax_cuda + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py new file mode 100644 index 000000000..0b9adb3a6 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -0,0 +1,581 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
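+
+# Example (hypothetical hyperparameters, shown for illustration only) of building a
+# BERT-style encoder with get_language_model() below:
+#
+#     language_model, lm_key = get_language_model(
+#         hidden_size=1024,
+#         ffn_hidden_size=4096,
+#         num_layers=24,
+#         max_position_embeddings=512,
+#         num_tokentypes=2,
+#         add_pooler=False,
+#         vocab_size=30592,
+#         num_attention_heads=16,
+#         encoder_attn_mask_type=AttnMaskType.padding,
+#     )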
+ +"""Transformer based language model.""" + +import torch +import torch.nn.functional as F +from apex.transformer import parallel_state, tensor_parallel +from apex.transformer.enums import AttnMaskType, LayerType + +from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer +from nemo.collections.nlp.modules.common.megatron.utils import ( + get_linear_layer, + init_method_normal, + scaled_init_method_normal, +) + + +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): + """LM logits using word embedding weights.""" + # Parallel logits. + input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) + # Matrix multiply. + if bias is None: + logits_parallel = F.linear(input_parallel, word_embeddings_weight) + else: + logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + # Gather if needed. + if parallel_output: + return logits_parallel + + return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) + + +def get_language_model( + hidden_size, + ffn_hidden_size, + num_layers, + max_position_embeddings, + num_tokentypes, + add_pooler, + vocab_size, + num_attention_heads, + encoder_attn_mask_type, + apply_query_key_layer_scaling=True, + kv_channels=None, + init_method=None, + scaled_init_method=None, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + pre_process=True, + post_process=True, + init_method_std=0.02, + use_cpu_initialization=False, + hidden_dropout=0.1, + fused_fp16=False, + fused_bf16=False, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + layernorm_epsilon=1e-5, + bias_gelu_fusion=True, + openai_gelu=False, + onnx_safe=False, +): + """Build language model and return along with the key to save.""" + + assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time." + + if kv_channels is None: + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + if init_method is None: + init_method = init_method_normal(init_method_std) + + if scaled_init_method is None: + scaled_init_method = scaled_init_method_normal(init_method_std, num_layers) + + # Language model. 
+ language_model = TransformerLanguageModel( + init_method=init_method, + output_layer_init_method=scaled_init_method, + encoder_attn_mask_type=encoder_attn_mask_type, + num_tokentypes=num_tokentypes, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + hidden_size=hidden_size, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + add_decoder=add_decoder, + decoder_attn_mask_type=decoder_attn_mask_type, + add_pooler=add_pooler, + pre_process=pre_process, + post_process=post_process, + use_cpu_initialization=use_cpu_initialization, + hidden_dropout=hidden_dropout, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + layernorm_epsilon=layernorm_epsilon, + bias_gelu_fusion=bias_gelu_fusion, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + ) + # key used for checkpoints. + language_model_key = 'language_model' + + return language_model, language_model_key + + +class Pooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, init_method): + super(Pooler, self).__init__() + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [b, s, h] + # sequence_index: index of the token to pool. + pooled = hidden_states[:, sequence_index, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled + + +class Embedding(MegatronModule): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + init_method, + num_tokentypes=0, + use_cpu_initialization=False, + ): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.init_method = init_method + self.num_tokentypes = num_tokentypes + + # Word embeddings (parallel). + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + vocab_size, self.hidden_size, init_method=self.init_method, use_cpu_initialization=use_cpu_initialization, + ) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. 
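+        # For example, BERT-style pretraining uses num_tokentypes=2 (sentence A/B segment
+        # embeddings), while num_tokentypes=0 skips this table entirely.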
+ self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) + # Initialize the token-type embeddings. + self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) + # Initialize the token-type embeddings. + self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + if tokentype_ids is not None: + assert self.tokentype_embeddings is not None + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + else: + assert self.tokentype_embeddings is None + + # Dropout. + embeddings = self.embedding_dropout(embeddings) + + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict( + destination, prefix, keep_vars + ) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. 
+ for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) + else: + print( + '***WARNING*** expected tokentype embeddings in the ' 'checkpoint but could not find it', + flush=True, + ) + + +class TransformerLanguageModel(MegatronModule): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__( + self, + init_method, + output_layer_init_method, + encoder_attn_mask_type, + vocab_size, + max_position_embeddings, + hidden_size, + ffn_hidden_size, + num_layers, + num_tokentypes, + num_attention_heads, + apply_query_key_layer_scaling=True, + kv_channels=None, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + add_pooler=False, + pre_process=True, + post_process=True, + use_cpu_initialization=False, + hidden_dropout=0.1, + fused_fp16=False, + fused_bf16=False, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + layernorm_epsilon=1e-5, + bias_gelu_fusion=True, + openai_gelu=False, + onnx_safe=False, + ): + super(TransformerLanguageModel, self).__init__() + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = hidden_size + self.num_layers = num_layers + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.num_tokentypes = num_tokentypes + self.init_method = init_method + self.encoder_attn_mask_type = encoder_attn_mask_type + self.add_decoder = add_decoder + self.decoder_attn_mask_type = decoder_attn_mask_type + self.add_pooler = add_pooler + self.hidden_dropout = hidden_dropout + self.output_layer_init_method = output_layer_init_method + + if kv_channels is None: + + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + # Embeddings. + if self.pre_process: + self.embedding = Embedding( + hidden_size=self.hidden_size, + vocab_size=self.vocab_size, + max_sequence_length=self.max_position_embeddings, + init_method=self.init_method, + num_tokentypes=self.num_tokentypes, + use_cpu_initialization=use_cpu_initialization, + embedding_dropout_prob=self.hidden_dropout, + ) + self._embedding_key = 'embedding' + + # Transformer. 
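+        # (The encoder is a ParallelTransformer stack over [s, b, h] activations;
+        # self_attn_mask_type is padding for BERT-style models and causal for GPT-style ones.)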
+ self.encoder = ParallelTransformer( + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + num_layers=self.num_layers, + hidden_size=self.hidden_size, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + self_attn_mask_type=self.encoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + use_cpu_initialization=use_cpu_initialization, + bias_gelu_fusion=bias_gelu_fusion, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + ) + self._encoder_key = 'encoder' + + # Decoder + if self.add_decoder: + assert ( + parallel_state.get_pipeline_model_parallel_world_size() == 1 + ), 'pipeline parallelism is not supported in the presence of decoder' + self.decoder = ParallelTransformer( + layer_type=LayerType.decoder, + self_attn_mask_type=self.decoder_attn_mask_type, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + num_layers=self.num_layers, + hidden_size=self.hidden_size, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + pre_process=self.pre_process, + post_process=self.post_process, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + use_cpu_initialization=use_cpu_initialization, + bias_gelu_fusion=bias_gelu_fusion, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + ) + self._decoder_key = 'decoder' + + if self.post_process: + # Pooler. + if self.add_pooler: + self.pooler = Pooler(self.hidden_size, self.init_method) + self._pooler_key = 'pooler' + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + self.encoder.set_input_tensor(input_tensor) + + def forward( + self, + enc_input_ids, + enc_position_ids, + enc_attn_mask, + dec_input_ids=None, + dec_position_ids=None, + dec_attn_mask=None, + enc_dec_attn_mask=None, + tokentype_ids=None, + layer_past=None, + get_key_value=False, + pooling_sequence_index=0, + enc_hidden_states=None, + output_enc_hidden=False, + ): + # Embeddings. + if self.pre_process: + embedding_output = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids) + encoder_input = embedding_output + else: + encoder_input = None + + # encoder. + if enc_hidden_states is None: + encoder_output = self.encoder( + encoder_input, enc_attn_mask, layer_past=layer_past, get_key_value=get_key_value + ) + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + if self.post_process: + if self.add_pooler: + pooled_output = self.pooler(encoder_output, pooling_sequence_index) + + # output_enc_hidden refers to when we just need the encoder's + # output. 
For example, it is helpful to compute + # similarity between two sequences by average pooling + if not self.add_decoder or output_enc_hidden: + if self.add_pooler and self.post_process: + return encoder_output, pooled_output + else: + return encoder_output + + # Decoder Embedding + dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids) + # decoder + decoder_output = self.decoder( + dec_embedding_output, + dec_attn_mask, + layer_past=layer_past, + get_key_value=get_key_value, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + ) + + if self.add_pooler and self.post_process: + return decoder_output, encoder_output, pooled_output + else: + return decoder_output, encoder_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + if self.add_decoder: + state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # for backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # for backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' 
in key: + state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.encoder.load_state_dict(state_dict_, strict=strict) + + if self.post_process: + # pooler + if self.add_pooler: + assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict) + # decoder + if self.add_decoder: + assert 'decoder' in state_dict, 'could not find data for pooler in the checkpoint' + self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_bert.py b/nemo/collections/nlp/modules/common/megatron/megatron_bert.py index 87a604e8e..962333296 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_bert.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_bert.py @@ -17,11 +17,8 @@ import os import torch -from megatron import get_args, initialize_megatron -from megatron.checkpointing import set_checkpoint_version -from megatron.model import get_language_model -from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids -from megatron.mpu import ( +from apex.transformer.enums import AttnMaskType +from apex.transformer.parallel_state import ( get_model_parallel_group, model_parallel_is_initialized, set_pipeline_model_parallel_rank, @@ -30,6 +27,7 @@ from megatron.mpu import ( from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.modules.common.bert_module import BertModule +from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model from nemo.core.classes import typecheck from nemo.utils import logging from nemo.utils.app_state import AppState @@ -37,13 +35,6 @@ from nemo.utils.app_state import AppState __all__ = ['MegatronBertEncoder'] -def complete_lazy_init(self): - # finish megatron-lm initialization - if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None: - self._lazy_init_fn() - self._lazy_init_fn = None - - class MegatronBertEncoder(BertModule): """ MegatronBERT wraps around the Megatron Language model @@ -59,96 +50,47 @@ class MegatronBertEncoder(BertModule): super().__init__() - self._model_parallel_size = model_parallel_size - self._model_parallel_rank = model_parallel_rank - self._restore_path = None - self._app_state = None - self._model_name = model_name - - if 'vocab_size' in config: - self._vocab_size = config.pop('vocab_size') - else: - self._vocab_size = None - - self._hidden_size = config.get('hidden_size') - - if not os.path.exists(vocab_file): - raise ValueError(f'Vocab file not found at {vocab_file}') - - # convert config to dictionary - if isinstance(config, DictConfig): - config = OmegaConf.to_container(config) - config["vocab_file"] = vocab_file - config['tokenizer_type'] = 'BertWordPieceLowerCase' - config['lazy_mpu_init'] = True - config['onnx_safe'] = True - - num_tokentypes = config.pop('num_tokentypes', 2) - - # if 'model_parallel_size' in config: - if self._model_parallel_size is not None: - app_state = AppState() - self._app_state = app_state - - # must be set for model parallel megatron-lm - os.environ["WORLD_SIZE"] = str(app_state.world_size) - os.environ["RANK"] = str(self._model_parallel_rank) - - extra_args_provider = self._update_megatron_args(tensor_model_parallel_size=self._model_parallel_size) - - else: - extra_args_provider = 
self._update_megatron_args() - - # configure globals for megatron - set_pipeline_model_parallel_rank(0) # pipeline model parallelism not implemented in NeMo - set_pipeline_model_parallel_world_size(1) # pipeline model parallelism not implemented in NeMo - - # Initialize part of Megatron global state that is needed for its constructor. - # We set 'lazy_mpu_init' flag on to make Megatron do only the initialization that does not depend - # on ddp be initialized yet (and we don't want Megatron to initialize DDP itself either) - # and to return a hook for us to call after PTL has torch.distributed initialized. - # (or if no PTL in case of inference - then we'll initialize torch.distributed) - # We call and clear this hook on first call to forward() - self._lazy_init_fn = initialize_megatron( - extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True + raise ValueError( + f'megatron-lm bert has been deprecated in NeMo 1.5. Please use an earlier release of NeMo. Megatron bert support will be added back to NeMo in a future release.' ) - # read Megatron arguments back - args = get_args() - logging.info(f'Megatron-lm argparse args: {args}') + # self._model_parallel_size = model_parallel_size + # self._model_parallel_rank = model_parallel_rank + # self._restore_path = None + # self._app_state = None + # self._model_name = model_name - self.language_model, self._language_model_key = get_language_model( - attention_mask_func=bert_attention_mask_func, num_tokentypes=num_tokentypes, add_pooler=False - ) + # if 'vocab_size' in config: + # self._vocab_size = config.pop('vocab_size') + # else: + # self._vocab_size = None - self.config = OmegaConf.create(config) - # key used for checkpoints - self._hidden_size = self.language_model.hidden_size + # self._hidden_size = config.get('hidden_size') - def _update_megatron_args( - self, - micro_batch_size=1, - tensor_model_parallel_size=1, - scaled_masked_softmax_fusion=False, - bias_gelu_fusion=False, - bias_dropout_fusion=False, - ): - def extra_args_provider(parser): - parser.set_defaults(micro_batch_size=micro_batch_size) - parser.set_defaults(tensor_model_parallel_size=tensor_model_parallel_size) - parser.set_defaults(scaled_masked_softmax_fusion=scaled_masked_softmax_fusion) - parser.set_defaults(bias_gelu_fusion=bias_gelu_fusion) - parser.set_defaults(bias_dropout_fusion=bias_dropout_fusion) + # if not os.path.exists(vocab_file): + # raise ValueError(f'Vocab file not found at {vocab_file}') - return parser + # # convert config to dictionary + # if isinstance(config, DictConfig): + # config = OmegaConf.to_container(config) + # config["vocab_file"] = vocab_file + # config['tokenizer_type'] = 'BertWordPieceLowerCase' + # config['lazy_mpu_init'] = True + # config['onnx_safe'] = True - return extra_args_provider + # num_tokentypes = config.pop('num_tokentypes', 2) - def complete_lazy_init(self): - # finish megatron-lm initialization - if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None: - self._lazy_init_fn() - self._lazy_init_fn = None + # # configure globals for megatron + # set_pipeline_model_parallel_rank(0) # pipeline model parallelism not implemented in NeMo + # set_pipeline_model_parallel_world_size(1) # pipeline model parallelism not implemented in NeMo + + # self.language_model, self._language_model_key = get_language_model( + # encoder_attn_mask_type=AttnMaskType.padding, num_tokentypes=num_tokentypes, add_pooler=False + # ) + + # self.config = OmegaConf.create(config) + # # key used for checkpoints + # 
self._hidden_size = self.language_model.hidden_size @property def hidden_size(self): @@ -172,17 +114,14 @@ class MegatronBertEncoder(BertModule): @typecheck() def forward(self, input_ids, attention_mask, token_type_ids=None): - app_state = AppState() - if app_state.model_parallel_size is None: - self.complete_lazy_init() extended_attention_mask = bert_extended_attention_mask(attention_mask) position_ids = bert_position_ids(input_ids) sequence_output = self.language_model( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=extended_attention_mask, + enc_input_ids=input_ids, + enc_position_ids=position_ids, + enc_attn_mask=extended_attention_mask, tokentype_ids=token_type_ids, ) return sequence_output @@ -196,13 +135,13 @@ class MegatronBertEncoder(BertModule): state_dict = torch.load(filename, map_location='cpu') if 'checkpoint_version' in state_dict: if state_dict['checkpoint_version'] is not None: - set_checkpoint_version(state_dict['checkpoint_version']) + set_megatron_checkpoint_version(state_dict['checkpoint_version']) logging.info( f"Megatron-lm checkpoint version found. Setting checkpoint_version to {state_dict['checkpoint_version']}." ) else: logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.') - set_checkpoint_version(0) + set_megatron_checkpoint_version(0) # to load from Megatron pretrained checkpoint if 'model' in state_dict: self.language_model.load_state_dict(state_dict['model'][self._language_model_key]) @@ -233,3 +172,39 @@ class MegatronBertEncoder(BertModule): logging.info(f'torch.distributed not initialized yet. Will not restore model parallel checkpoint') else: logging.error(f'restore_path: {restore_path} must be a file or directory.') + + +def get_megatron_checkpoint_version(): + app_state = AppState() + return app_state._megatron_checkpoint_version + + +def set_megatron_checkpoint_version(version: int = None): + app_state = AppState() + app_state._megatron_checkpoint_version = version + + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = extended_attention_mask < 0.5 + + return extended_attention_mask + + +def bert_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py new file mode 100644 index 000000000..0a410afde --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import torch +from apex.transformer import tensor_parallel +from apex.transformer.parallel_state import ( + get_pipeline_model_parallel_rank, + set_pipeline_model_parallel_rank, + set_pipeline_model_parallel_world_size, + set_tensor_model_parallel_rank, + set_tensor_model_parallel_world_size, +) + +from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank +from nemo.utils import AppState + + +def initialize_model_parallel_for_nemo( + world_size, global_rank, local_rank, tensor_model_parallel_size=1, seed=1234, +): + + # updating NeMo globals + app_state = AppState() + app_state.global_rank = global_rank + app_state.world_size = world_size + app_state.model_parallel_size = tensor_model_parallel_size + app_state.model_parallel_rank = compute_model_parallel_rank(local_rank, tensor_model_parallel_size) + + # update apex.mpu globals + set_tensor_model_parallel_world_size(tensor_model_parallel_size) + set_tensor_model_parallel_rank(app_state.model_parallel_rank) + + # pipeline model parallelism not implemented in NeMo yet + set_pipeline_model_parallel_rank(0) + set_pipeline_model_parallel_world_size(1) + + _set_random_seed(seed) + + app_state._is_megatron_initialized = True + + +def _set_random_seed(seed_): + """Set random seed for reproducability.""" + if seed_ is not None and seed_ > 0: + # Ensure that different pipeline MP stages get different seeds. 
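+        # e.g. seed_ = 1234 gives seed 1234 on pipeline rank 0, 1334 on rank 1, 1434 on
+        # rank 2 (NeMo currently runs with a single pipeline stage, so the rank is always 0).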
+ seed = seed_ + (100 * get_pipeline_model_parallel_rank()) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.device_count() > 0: + tensor_parallel.model_parallel_cuda_manual_seed(seed) + else: + raise ValueError('Seed ({}) should be a positive integer.'.format(seed_)) + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options.""" + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py index 4f2056fde..08c58a898 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py @@ -46,6 +46,14 @@ MEGATRON_CACHE = os.path.join(torch_home, "megatron") CONFIGS = {"345m": {"hidden_size": 1024, "num_attention_heads": 16, "num_layers": 24, "max_position_embeddings": 512}} MEGATRON_CONFIG_MAP = { + "megatron-gpt-345m": { + "config": CONFIGS["345m"], + "checkpoint": "models/nvidia/megatron_lm_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt", + "vocab": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", + "merges_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", + "do_lower_case": False, + "tokenizer_name": "gpt2", + }, "megatron-bert-345m-uncased": { "config": CONFIGS["345m"], "checkpoint": "https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt", @@ -97,6 +105,7 @@ def get_megatron_lm_model( config_file: Optional[str] = None, checkpoint_file: Optional[str] = None, vocab_file: Optional[str] = None, + merges_file: Optional[str] = None, ) -> Tuple[MegatronBertEncoder, str]: """ Returns MegatronBertEncoder and a default or user specified path to the checkpoint file @@ -113,81 +122,87 @@ def get_megatron_lm_model( model: MegatronBertEncoder checkpoint_file: path to checkpoint file or directory """ - config = None - # get default config and checkpoint - if config_file: - with open(config_file) as f: - config = json.load(f) - # replace dashes with underscores in config keys - fixed_config = {} - for key in config.keys(): - fixed_key = key.replace("-", "_") - if fixed_key == 'max_seq_length': - fixed_key = 'max_position_embeddings' - fixed_config[fixed_key] = config[key] - # 'vocab_size" no longer used. 
- if 'vocab_size' in fixed_config: - fixed_config.pop('vocab_size') - config = fixed_config - elif config_dict: - config = config_dict - elif pretrained_model_name in get_megatron_lm_models_list(): - config = get_megatron_config(pretrained_model_name) - else: - raise ValueError(f"{pretrained_model_name} is not supported") - - if config is None: - raise ValueError(f"config_file or config_dict is required for {pretrained_model_name}") - - if not checkpoint_file: - checkpoint_file = get_megatron_checkpoint(pretrained_model_name) - - if not vocab_file: - vocab_file = get_megatron_vocab_file(pretrained_model_name) - - app_state = AppState() - if app_state.model_parallel_size is not None and app_state.model_parallel_rank is not None: - # model parallel already known from .nemo restore - model_parallel_size = app_state.model_parallel_size - model_parallel_rank = app_state.model_parallel_rank - elif os.path.isdir(checkpoint_file): - # starting training from megatron-lm checkpoint - mp_ranks = glob.glob(os.path.join(checkpoint_file, 'mp_rank*')) - model_parallel_size = len(mp_ranks) - app_state.model_parallel_size = model_parallel_size - logging.info( - ( - f'restore_path: {checkpoint_file} is a directory. ' - f'Assuming megatron model parallelism with ' - f'model_parallel_size: {model_parallel_size}' - ) - ) - # try to get local rank from global - local_rank = None - try: - local_rank = int(os.environ['LOCAL_RANK']) - except: - logging.info('Global variable LOCAL_RANK not yet specified') - if local_rank is not None: - app_state.local_rank = local_rank - else: - # if local is None then we are on the main process - local_rank = 0 - model_parallel_rank = compute_model_parallel_rank(local_rank, model_parallel_size) - app_state.model_parallel_rank = model_parallel_rank - else: - model_parallel_size = None - model_parallel_rank = None - - model = MegatronBertEncoder( - model_name=pretrained_model_name, - config=config, - vocab_file=vocab_file, - model_parallel_size=model_parallel_size, - model_parallel_rank=model_parallel_rank, + raise ValueError( + f'megatron-lm bert has been deprecated in NeMo 1.5. Please use an earlier release of NeMo. Megatron bert support will be added back to NeMo in a future release.' ) + # config = None + # # get default config and checkpoint + # if config_file: + # with open(config_file) as f: + # config = json.load(f) + # # replace dashes with underscores in config keys + # fixed_config = {} + # for key in config.keys(): + # fixed_key = key.replace("-", "_") + # if fixed_key == 'max_seq_length': + # fixed_key = 'max_position_embeddings' + # fixed_config[fixed_key] = config[key] + # # 'vocab_size" no longer used. 
+ # if 'vocab_size' in fixed_config: + # fixed_config.pop('vocab_size') + # config = fixed_config + # elif config_dict: + # config = config_dict + # elif pretrained_model_name in get_megatron_lm_models_list(): + # config = get_megatron_config(pretrained_model_name) + # else: + # raise ValueError(f"{pretrained_model_name} is not supported") - return model, checkpoint_file + # if config is None: + # raise ValueError(f"config_file or config_dict is required for {pretrained_model_name}") + + # if not checkpoint_file: + # checkpoint_file = get_megatron_checkpoint(pretrained_model_name) + + # if not vocab_file: + # vocab_file = get_megatron_vocab_file(pretrained_model_name) + + # if not merges_file: + # merges_file = get_megatron_merges_file(pretrained_model_name) + + # app_state = AppState() + # if app_state.model_parallel_size is not None and app_state.model_parallel_rank is not None: + # # model parallel already known from .nemo restore + # model_parallel_size = app_state.model_parallel_size + # model_parallel_rank = app_state.model_parallel_rank + # elif os.path.isdir(checkpoint_file): + # # starting training from megatron-lm checkpoint + # mp_ranks = glob.glob(os.path.join(checkpoint_file, 'mp_rank*')) + # model_parallel_size = len(mp_ranks) + # app_state.model_parallel_size = model_parallel_size + # logging.info( + # ( + # f'restore_path: {checkpoint_file} is a directory. ' + # f'Assuming megatron model parallelism with ' + # f'model_parallel_size: {model_parallel_size}' + # ) + # ) + # # try to get local rank from global + # local_rank = None + # try: + # local_rank = int(os.environ['LOCAL_RANK']) + # except: + # logging.info('Global variable LOCAL_RANK not yet specified') + # if local_rank is not None: + # app_state.local_rank = local_rank + # else: + # # if local is None then we are on the main process + # local_rank = 0 + # model_parallel_rank = compute_model_parallel_rank(local_rank, model_parallel_size) + # app_state.model_parallel_rank = model_parallel_rank + # else: + # model_parallel_size = None + # model_parallel_rank = None + + # model = MegatronBertEncoder( + # model_name=pretrained_model_name, + # config=config, + # vocab_file=vocab_file, + # model_parallel_size=model_parallel_size, + # model_parallel_rank=model_parallel_rank, + # ) + + # return model, checkpoint_file def compute_model_parallel_rank(local_rank, model_parallel_size): @@ -239,6 +254,26 @@ def get_megatron_vocab_file(pretrained_model_name: str) -> str: return path +def get_megatron_merges_file(pretrained_model_name: str) -> str: + """ + Gets merge file from cache or downloads it + + Args: + pretrained_model_name: pretrained model name + + Returns: + path: path to the vocab file + """ + if 'gpt' not in pretrained_model_name.lower(): + return None + _check_megatron_name(pretrained_model_name) + url = MEGATRON_CONFIG_MAP[pretrained_model_name]["merges_file"] + + path = os.path.join(MEGATRON_CACHE, pretrained_model_name + "_merges") + path = _download(path, url) + return path + + def get_megatron_checkpoint(pretrained_model_name: str) -> str: """ Gets checkpoint file from cache or downloads it diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py new file mode 100644 index 000000000..337a6f1e9 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Module""" + +import torch +from apex.transformer import parallel_state, tensor_parallel + + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + +class MegatronModule(torch.nn.Module): + """Megatron specific extensions of torch Module with support + for pipelining.""" + + def __init__(self, share_word_embeddings=True): + super(MegatronModule, self).__init__() + self.share_word_embeddings = share_word_embeddings + + def word_embeddings_weight(self): + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + return self.language_model.embedding.word_embeddings.weight + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + if not self.share_word_embeddings: + raise Exception( + 'word_embeddings_weight() called for last ' 'stage, but share_word_embeddings is false' + ) + return self.word_embeddings.weight + raise Exception('word_embeddings_weight() should be ' 'called for first and last stage only') + + def initialize_word_embeddings(self, init_method, vocab_size, hidden_size, pipeline_model_parallel_size=1): + if not self.share_word_embeddings: + raise Exception('initialize_word_embeddings() was called but ' 'share_word_embeddings is false') + + # TODO: pipeline model parallelism is not implemented in NeMo yet + # This function just initializes the word embeddings in the final stage + # when we are using pipeline parallelism. If we aren't using pipeline + # parallelism there is nothing to do. + if pipeline_model_parallel_size == 1: + return + + # Parameters are shared between the word embeddings layer, and the + # heads at the end of the model. In a pipelined setup with more than + # one stage, the initial embedding layer and the head are on different + # workers, so we do the following: + # 1. Create a second copy of word_embeddings on the last stage, with + # initial parameters of 0.0. + # 2. Do an all-reduce between the first and last stage to ensure that + # the two copies of word_embeddings start off with the same + # parameter values. + # 3. In the training loop, before an all-reduce between the grads of + # the two word_embeddings layers to ensure that every applied weight + # update is the same on both stages. + if parallel_state.is_pipeline_last_stage(): + assert not parallel_state.is_pipeline_first_stage() + self._word_embeddings_for_head_key = 'word_embeddings_for_head' + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init_method + ) + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + + # Ensure that first and last stages have the same initial parameter + # values. 
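A minimal single-process sketch of the synchronization idea described in the comments above, where two local tensors stand in for the first- and last-stage embedding copies; the real code instead performs the torch.distributed.all_reduce over the embedding group just below:

import torch

first_stage_emb = torch.randn(8, 4)                   # weights owned by the first stage
last_stage_emb = torch.zeros_like(first_stage_emb)    # last-stage copy starts at 0.0

# An all-reduce (sum) over the two copies leaves both equal to the first stage's values.
synced = first_stage_emb + last_stage_emb
assert torch.equal(synced, first_stage_emb)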
+ if torch.distributed.is_initialized(): + if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage(): + torch.distributed.all_reduce( + self.word_embeddings_weight().data, group=parallel_state.get_embedding_group() + ) + else: + print( + "WARNING! Distributed processes aren't initialized, so " + "word embeddings in the last layer are not initialized. " + "If you are just manipulating a model this is fine, but " + "this needs to be handled manually. If you are training " + "something is definitely wrong." + ) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py new file mode 100644 index 000000000..96c4e0c29 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -0,0 +1,850 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer.""" +import math + +import torch +import torch.nn.functional as F +from apex.normalization.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from apex.transformer import parallel_state, tensor_parallel +from apex.transformer.enums import AttnMaskType, AttnType, LayerType +from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax + +from nemo.collections.nlp.modules.common.megatron.fused_bias_dropout_add import ( + BiasDropoutAddFusedInference, + BiasDropoutAddFusedTrain, + bias_dropout_add, +) +from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import FusedBiasGeLU +from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu + + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + + +class ParallelMLP(MegatronModule): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__( + self, + init_method, + output_layer_init_method, + hidden_size, + ffn_hidden_size, + use_cpu_initialization=False, + bias_gelu_fusion=True, + openai_gelu=False, + onnx_safe=False, + fused_fp16=False, + fused_bf16=False, + ): + super(ParallelMLP, self).__init__() + + # Project to 4h. 
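Before the tensor-parallel implementation below, it may help to see the equivalent data flow in plain PyTorch. This is a simplified sketch that ignores the column/row partitioning and the fused bias-GeLU path; the class name is illustrative only:

import torch
import torch.nn.functional as F
from torch import nn

class ReferenceMLP(nn.Module):
    """Non-parallel sketch of the h -> 4h -> GeLU -> h MLP described above."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size)
        self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(hidden_states)))

# Example: [s, b, h] in, same shape out.
out = ReferenceMLP(hidden_size=16, ffn_hidden_size=64)(torch.randn(5, 2, 16))
assert out.shape == (5, 2, 16)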
+ self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( + hidden_size, + ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True, + use_cpu_initialization=use_cpu_initialization, + ) + + self.bias_gelu_fusion = bias_gelu_fusion + self.activation_func = F.gelu + if openai_gelu: + self.activation_func = openai_gelu + elif onnx_safe: + self.activation_func = erf_gelu + + # Project back to h. + self.dense_4h_to_h = tensor_parallel.RowParallelLinear( + ffn_hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + use_cpu_initialization=use_cpu_initialization, + ) + + self.bias_gelu_impl = FusedBiasGeLU(fused_fp16, fused_bf16) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = self.bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + + +class ParallelAttention(MegatronModule): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + """ + + def __init__( + self, + init_method, + output_layer_init_method, + layer_number, + num_attention_heads, + hidden_size, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding, + fused_fp16=False, + fused_bf16=False, + apply_query_key_layer_scaling=True, + kv_channels=None, + use_cpu_initialization=False, + masked_softmax_fusion=True, + attention_dropout=0.1, + ): + super(ParallelAttention, self).__init__() + + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = False + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + + if kv_channels is None: + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + projection_size = kv_channels * num_attention_heads + + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = tensor_parallel.divide(projection_size, world_size) + self.hidden_size_per_attention_head = tensor_parallel.divide(projection_size, num_attention_heads) + self.num_attention_heads_per_partition = tensor_parallel.divide(num_attention_heads, world_size) + + # Strided linear layer. 
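The per-partition sizes computed above determine how the strided QKV projection below is split across tensor-parallel ranks. A small numeric sketch with hypothetical sizes shows the quantities involved:

# Hypothetical sizes, purely to illustrate the per-partition arithmetic above.
hidden_size = 1024
num_attention_heads = 16
world_size = 2                                                 # tensor-model-parallel size

kv_channels = hidden_size // num_attention_heads               # hn = 64
projection_size = kv_channels * num_attention_heads            # 1024
hidden_size_per_partition = projection_size // world_size      # hp = 512
num_heads_per_partition = num_attention_heads // world_size    # np = 8

print(kv_channels, projection_size, hidden_size_per_partition, num_heads_per_partition)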
+ if attention_type == AttnType.self_attn: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + hidden_size, + 3 * projection_size, + gather_output=False, + init_method=init_method, + use_cpu_initialization=use_cpu_initialization, + ) + else: + assert attention_type == AttnType.cross_attn + self.query = tensor_parallel.ColumnParallelLinear( + hidden_size, projection_size, gather_output=False, init_method=init_method + ) + + self.key_value = tensor_parallel.ColumnParallelLinear( + hidden_size, 2 * projection_size, gather_output=False, init_method=init_method + ) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + fused_fp16, + fused_bf16, + self.attn_mask_type, + masked_softmax_fusion, + attention_mask_func, + self.attention_softmax_in_fp32, + coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(attention_dropout) + + # Output. + self.dense = tensor_parallel.RowParallelLinear( + projection_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + use_cpu_initialization=use_cpu_initialization, + ) + + def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): + # hidden_states: [sq, b, h] + + # ===================== + # Query, Key, and Value + # ===================== + + if self.attention_type == AttnType.self_attn: + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3) + else: + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv_layer, _ = self.key_value(encoder_output) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query_layer, _ = self.query(hidden_states) + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query_layer = query_layer.view(*new_tensor_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + if layer_past is not None: + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) + if get_key_value: + present = (key_layer, value_layer) + + # =================================== + # Raw attention scores. 
[b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting result tensor: [b * np, sq, sk] + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=torch.cuda.current_device(), + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # ================================================== + # Update attention mask for inference. [b, np, sq, sk] + # ================================================== + + if get_key_value: + with torch.no_grad(): + if layer_past is not None: + attention_mask = attention_mask[ + ..., attention_scores.size(3) - 1, : attention_scores.size(3) + ].unsqueeze(2) + else: + attention_mask = attention_mask[..., : attention_scores.size(3), : attention_scores.size(3)] + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with tensor_parallel.random.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + if get_key_value: + output = [output, present] + + return output, bias + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + + return _bias_dropout_add + + +class ParallelTransformerLayer_(MegatronModule): + """A single transformer layer. 
+ + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__( + self, + init_method, + output_layer_init_method, + layer_number, + hidden_size, + ffn_hidden_size, + num_attention_heads, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + fp32_residual_connection=False, + apply_residual_connection_post_layernorm=False, + fused_fp16=False, + fused_bf16=False, + apply_query_key_layer_scaling=True, + kv_channels=None, + layernorm_epsilon=1e-5, + hidden_dropout=0.1, + bias_dropout_fusion=True, + use_cpu_initialization=False, + bias_gelu_fusion=True, + openai_gelu=False, + onnx_safe=False, + masked_softmax_fusion=True, + attention_dropout=0.1, + ): + super(ParallelTransformerLayer_, self).__init__() + + assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time." + + if kv_channels is None: + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm # if true apply residual connection post layer norm (like original bert) + + self.bf16 = fused_bf16 + self.fp32_residual_connection = fp32_residual_connection # if true move residual connections to fp32 + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16) + + # Self attention. + self.self_attention = ParallelAttention( + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + use_cpu_initialization=use_cpu_initialization, + masked_softmax_fusion=masked_softmax_fusion, + attention_dropout=attention_dropout, + ) + self.hidden_dropout = hidden_dropout + self.bias_dropout_fusion = bias_dropout_fusion # if true, enable bias dropout fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16) + + if self.layer_type == LayerType.decoder: + self.inter_attention = ParallelAttention( + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + attention_type=AttnType.cross_attn, + attn_mask_type=AttnMaskType.padding, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + use_cpu_initialization=use_cpu_initialization, + masked_softmax_fusion=masked_softmax_fusion, + attention_dropout=attention_dropout, + ) + # Layernorm on the attention output. 
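The apply_residual_connection_post_layernorm flag stored above controls where the residual is taken in the forward pass further below. A simplified sketch of the two choices, with a generic sublayer standing in for attention or the MLP and the bias/dropout handling omitted:

import torch
from torch import nn

def residual_block(x, layernorm, sublayer, post_layernorm_residual: bool):
    # Pre-LN keeps the residual on the raw input; post-LN (original BERT style)
    # keeps it on the normalized input, mirroring the selection in forward() below.
    normalized = layernorm(x)
    residual = normalized if post_layernorm_residual else x
    return residual + sublayer(normalized)

ln, sub = nn.LayerNorm(16), nn.Linear(16, 16)
x = torch.randn(4, 16)
pre_ln_out = residual_block(x, ln, sub, post_layernorm_residual=False)
post_ln_out = residual_block(x, ln, sub, post_layernorm_residual=True)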
+ self.post_inter_attention_layernorm = LayerNorm( + hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16 + ) + + # MLP + self.mlp = ParallelMLP( + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + use_cpu_initialization=use_cpu_initialization, + bias_gelu_fusion=bias_gelu_fusion, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + ) + self.bias_dropout_add_fused_train = BiasDropoutAddFusedTrain(fused_fp16, fused_bf16) + self.bias_dropout_add_fused_inference = BiasDropoutAddFusedInference(fused_fp16, fused_bf16) + + self.bias_dropout_add_fused_train = BiasDropoutAddFusedTrain(fused_fp16, fused_bf16) + self.bias_dropout_add_fused_inference = BiasDropoutAddFusedInference(fused_fp16, fused_bf16) + + def forward( + self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + layer_past=None, + get_key_value=False, + ): + # hidden_states: [b, s, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, attention_bias = self.self_attention( + layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value + ) + + if get_key_value: + attention_output, presents = attention_output + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = self.bias_dropout_add_fused_train + else: + bias_dropout_add_func = self.bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout + ) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + if self.layer_type == LayerType.decoder: + attention_output, attention_bias = self.inter_attention( + layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output + ) + # residual connection + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout + ) + + # Layer norm post the decoder attention + layernorm_output = self.post_inter_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + # re-enable torch grad to enable fused optimization. 
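When bias_dropout_fusion is off, bias_dropout_add_func falls back to the plain composition selected above. As a reference, that unfused path boils down to roughly the following sketch (consistent with the bias_dropout_add helper imported at the top of this file):

import torch
import torch.nn.functional as F

def bias_dropout_add_reference(x, bias, residual, prob, training):
    # Unfused reference: add the bias, apply dropout, then add the residual.
    return residual + F.dropout(x + bias, p=prob, training=training)

x, bias, residual = torch.randn(5, 2, 16), torch.randn(16), torch.randn(5, 2, 16)
out = bias_dropout_add_reference(x, bias, residual, prob=0.1, training=True)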
+ with torch.enable_grad(): + output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) + + if get_key_value: + output = [output, presents] + + return output + + +class ParallelTransformerLayer(ParallelTransformerLayer_): + def __init__(self, **kwargs): + super(ParallelTransformerLayer, self).__init__(**kwargs) + + assert not (kwargs['fused_fp16'] and kwargs['fused_bf16']) + if kwargs['fused_fp16']: + self.dtype = torch.float16 + elif kwargs['fused_bf16']: + self.dtype = torch.bfloat16 + else: + self.dtype = torch.float32 + + def forward( + self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + layer_past=None, + get_key_value=False, + ): + + if self.dtype == torch.float32: + return super().forward( + hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value + ) + else: + with torch.cuda.amp.autocast(dtype=self.dtype): + return super().forward( + hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value + ) + + +class ParallelTransformer(MegatronModule): + """Transformer class.""" + + def __init__( + self, + init_method, + output_layer_init_method, + num_layers, + hidden_size, + ffn_hidden_size, + num_attention_heads, + apply_query_key_layer_scaling=True, + kv_channels=None, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + pre_process=True, + post_process=True, + fused_fp16=False, + fused_bf16=False, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + layernorm_epsilon=1e-5, + hidden_dropout=0.1, + use_cpu_initialization=False, + bias_gelu_fusion=True, + openai_gelu=False, + onnx_safe=False, + ): + super(ParallelTransformer, self).__init__() + + assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time." + + if kv_channels is None: + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + self.bf16 = fused_bf16 + self.fp32_residual_connection = fp32_residual_connection + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + + # Store activation checkpointing flag. + self.activations_checkpoint_method = activations_checkpoint_method + self.activations_checkpoint_num_layers = activations_checkpoint_num_layers + + # Number of layers. + assert ( + num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0 + ), 'num_layers must be divisible by pipeline_model_parallel_size' + self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size() + + # Transformer layers. 
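Each pipeline stage builds only num_layers // pipeline_model_parallel_size layers and numbers them globally via the offset used below. A small arithmetic sketch with hypothetical sizes:

# Hypothetical configuration, illustrating how layers are split across pipeline stages
# and how the 1-based global layer numbers are assigned (see build_layer/offset below).
total_layers = 24
pipeline_world_size = 4
layers_per_stage = total_layers // pipeline_world_size   # 6

for pipeline_rank in range(pipeline_world_size):
    offset = pipeline_rank * layers_per_stage
    layer_numbers = [i + 1 + offset for i in range(layers_per_stage)]
    print(f"stage {pipeline_rank}: layers {layer_numbers}")
# stage 0: layers 1..6, stage 1: layers 7..12, and so on.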
+ def build_layer(layer_number): + return ParallelTransformerLayer( + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + layer_type=layer_type, + self_attn_mask_type=self_attn_mask_type, + fused_fp16=fused_fp16, + fused_bf16=fused_bf16, + fp32_residual_connection=fp32_residual_connection, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + use_cpu_initialization=use_cpu_initialization, + bias_gelu_fusion=bias_gelu_fusion, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + ) + + # TODO: get virtual_pipeline_model_parallel_size from apex.mpu + # if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None: + # assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, ( + # 'num_layers_per_stage must be divisible by ' 'virtual_pipeline_model_parallel_size' + # ) + # # Number of layers in each model chunk is the number of layers in the stage, + # # divided by the number of model chunks in a stage. + # self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size + # # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # # layers to stages like (each list is a model chunk): + # # Stage 0: [0] [2] [4] [6] + # # Stage 1: [1] [3] [5] [7] + # # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # # layers to stages like (each list is a model chunk): + # # Stage 0: [0, 1] [4, 5] + # # Stage 1: [2, 3] [6, 7] + # offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * ( + # args.num_layers // args.virtual_pipeline_model_parallel_size + # ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers) + # else: + # # Each stage gets a contiguous set of layers. + # offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers + offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers + + self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)]) + + if self.post_process: + # Final layer norm before output. + self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask): + """Forward method with activation checkpointing.""" + + def custom(start, end): + def custom_forward(*inputs): + x_ = inputs[0] + attention_mask = inputs[1] + encoder_output = inputs[2] + enc_dec_attn_mask = inputs[3] + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask) + return x_ + + return custom_forward + + # Make sure memory is freed. + tensor_parallel.reset_checkpointed_activations_memory_buffer() + + if self.activations_checkpoint_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. 
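The 'uniform' loop just below checkpoints the layer stack in chunks of activations_checkpoint_num_layers layers, keeping only each chunk's input activation. A simplified stand-in using stock torch.utils.checkpoint (in place of tensor_parallel.checkpoint) sketches the same chunking:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Simplified stand-in for the 'uniform' activation-checkpointing loop below.
layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))
chunk_size = 2  # plays the role of activations_checkpoint_num_layers

def run_chunk(start, end):
    def custom_forward(x):
        for layer in layers[start:end]:
            x = layer(x)
        return x
    return custom_forward

hidden = torch.randn(4, 16, requires_grad=True)
l = 0
while l < len(layers):
    hidden = checkpoint(run_chunk(l, l + chunk_size), hidden)
    l += chunk_size
hidden.sum().backward()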
+ l = 0 + while l < self.num_layers: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + self.activations_checkpoint_num_layers), + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + ) + l += self.activations_checkpoint_num_layers + elif self.activations_checkpoint_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers): + if l < self.activations_checkpoint_num_layers: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + 1), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask + ) + else: + hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) + else: + raise ValueError("Invalid activation checkpoint method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + encoder_output=None, + enc_dec_attn_mask=None, + ): + # Checks. + if layer_past is not None: + assert get_key_value, 'for not None values in layer_past, ' 'expected get_key_value to be set' + if get_key_value: + assert self.activations_checkpoint_method is None, ( + 'get_key_value does not work with ' 'activation checkpointing' + ) + + if self.pre_process: + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + hidden_states = hidden_states.transpose(0, 1).contiguous().float() + # Otherwise, leave it as is. + else: + hidden_states = hidden_states.transpose(0, 1).contiguous() + else: + # See set_input_tensor() + hidden_states = self.input_tensor + + if encoder_output is not None: + encoder_output = encoder_output.transpose(0, 1).contiguous() + + if self.activations_checkpoint_method is not None: + hidden_states = self._checkpointed_forward( + hidden_states, attention_mask, encoder_output, enc_dec_attn_mask + ) + else: + if get_key_value: + presents = [] + for index in range(self.num_layers): + layer = self._get_layer(index) + past = None + if layer_past is not None: + past = layer_past[index] + hidden_states = layer( + hidden_states, + attention_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + ) + if get_key_value: + hidden_states, present = hidden_states + presents.append(present) + + # Final layer norm. + if self.post_process: + # Reverting data format change [s b h] --> [b s h]. 
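The layers run in sequence-first layout, so the input is transposed on entry (above) and transposed back right below. A minimal illustration of the layout round trip:

import torch

# [b, s, h] at the module boundary, [s, b, h] inside the transformer stack.
b, s, h = 2, 5, 16
hidden_states = torch.randn(b, s, h)

internal = hidden_states.transpose(0, 1).contiguous()   # [s, b, h] for the layers
restored = internal.transpose(0, 1).contiguous()        # back to [b, s, h] at the end

assert restored.shape == (b, s, h)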
+ hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.final_layernorm(hidden_states) + else: + output = hidden_states + if get_key_value: + output = [output, presents] + + return output diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py new file mode 100644 index 000000000..18366cb0e --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -0,0 +1,166 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for models.""" + +import math + +import torch +from apex.transformer import parallel_state + + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +def get_linear_layer(rows, columns, init_method): + """Simple linear layer with weight initialization.""" + layer = torch.nn.Linear(rows, columns) + init_method(layer.weight) + with torch.no_grad(): + layer.bias.zero_() + return layer + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + +def openai_gelu(x): + return gelu_impl(x) + + +# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter +@torch.jit.script +def erf_gelu(x): + return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype)) + + +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group()) + averaged_losses = averaged_losses / torch.distributed.get_world_size( + group=parallel_state.get_data_parallel_group() + ) + + return averaged_losses + + +def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length + ) + + # Loss mask. 
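The attention mask built above starts as a lower-triangular matrix of ones; further below it is flipped into a boolean "positions to mask" tensor (attention_mask < 0.5), which attention_mask_func then fills with -10000 before the softmax. A compact sketch of that round trip for a single short sequence:

import torch

seq_length = 4
# Lower-triangular causal mask of ones, shaped [1, 1, s, s] as above.
attention_mask = torch.tril(torch.ones(1, seq_length, seq_length)).view(1, 1, seq_length, seq_length)

# Converted to boolean: True marks positions that must NOT be attended to.
attention_mask = attention_mask < 0.5

# attention_mask_func pushes those positions toward -inf before the softmax.
scores = torch.zeros(1, 1, seq_length, seq_length)
scores.masked_fill_(attention_mask, -10000.0)
print(scores[0, 0])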
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indicies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1) :] -= i + 1 - prev_index + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + + return attention_mask, loss_mask, position_ids + + +class AutocastModuleWrapper(torch.nn.Module): + def __init__(self, fp16=False, bf16=False): + super(AutocastModuleWrapper, self).__init__() + + # TODO: @titu1994 Maybe useful to make this adaptive - prefer bf16 over fp16 when both flags are set and both dtypes are supported? + assert not (fp16 and bf16) + if fp16 or bf16: + self.precision = torch.float16 if fp16 else torch.bfloat16 + self.use_autocast = True + else: + self.use_autocast = False + + # class or function to autocast should be defined + self.func = None + + def autocast_forward(self, *args): + + assert self.func is not None + if self.use_autocast: + if isinstance(self.func, torch.autograd.function.FunctionMeta): + fwd = torch.cuda.amp.custom_fwd(cast_inputs=self.precision)(self.func.apply) + else: + fwd = torch.cuda.amp.custom_fwd(cast_inputs=self.precision)(self.func) + return fwd(*args) + else: + if isinstance(self.func, torch.autograd.function.FunctionMeta): + return self.func.apply(*args) + else: + return self.func(*args) diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index d08cdc3a6..b7fe3c4d0 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -31,6 +31,13 @@ from nemo.utils import logging __all__ = ['get_tokenizer', 'get_tokenizer_list'] +megatron_tokenizer_model_map = { + 'BertWordPieceLowerCase': 'megatron-bert-345m-uncased', + 'BertWordPieceCase': 'megatron-bert-345m-cased', + 'GPT2BPETokenizer': 'megatron-gpt-345m', +} + + def get_tokenizer_list() -> List[str]: """ Returns all all supported tokenizer names @@ -57,6 +64,7 @@ def get_tokenizer( tokenizer_name: str, tokenizer_model: Optional[str] = None, vocab_file: Optional[str] = None, + merges_file: Optional[str] = None, special_tokens: Optional[Dict[str, str]] = None, use_fast: Optional[bool] = False, bpe_dropout: Optional[float] = 0.0, @@ -84,6 +92,9 @@ def get_tokenizer( vocab_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_vocab_file( tokenizer_name ) + merges_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_merges_file( + tokenizer_name + ) tokenizer_name = get_megatron_tokenizer(tokenizer_name) if tokenizer_name == 'sentencepiece': @@ -101,7 
+112,11 @@ def get_tokenizer( f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_name}, vocab_file: {vocab_file}, special_tokens_dict: {special_tokens_dict}, and use_fast: {use_fast}" ) return AutoTokenizer( - pretrained_model_name=tokenizer_name, vocab_file=vocab_file, **special_tokens_dict, use_fast=use_fast + pretrained_model_name=tokenizer_name, + vocab_file=vocab_file, + merges_file=merges_file, + **special_tokens_dict, + use_fast=use_fast, ) @@ -110,6 +125,7 @@ def get_nmt_tokenizer( model_name: Optional[str] = None, tokenizer_model: Optional[str] = None, vocab_file: Optional[str] = None, + merges_file: Optional[str] = None, special_tokens: Optional[Dict[str, str]] = None, use_fast: Optional[bool] = False, bpe_dropout: Optional[float] = 0.0, @@ -138,10 +154,14 @@ def get_nmt_tokenizer( elif library == 'huggingface': logging.info(f'Getting HuggingFace AutoTokenizer with pretrained_model_name: {model_name}') return AutoTokenizer( - pretrained_model_name=model_name, vocab_file=vocab_file, **special_tokens_dict, use_fast=use_fast + pretrained_model_name=model_name, + vocab_file=vocab_file, + merges_file=merges_file, + **special_tokens_dict, + use_fast=use_fast, ) elif library == 'sentencepiece': - logging.info(f'Getting SentencePiece with model: {model_name}') + logging.info(f'Getting SentencePiece with model: {tokenizer_model}') return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( model_path=tokenizer_model, special_tokens=special_tokens_dict ) @@ -149,10 +169,12 @@ def get_nmt_tokenizer( logging.info(f'Using byte-level tokenization') return ByteLevelTokenizer() elif library == 'megatron': + if model_name in megatron_tokenizer_model_map: + model_name = megatron_tokenizer_model_map[model_name] logging.info( f'Getting Megatron tokenizer for pretrained model name: {model_name} and custom vocab file: {vocab_file}' ) - return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file) + return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file) else: raise NotImplementedError( 'Currently we only support "yttm", "huggingface", "sentencepiece", "megatron", and "byte-level" tokenizer libraries.' diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 53b520020..9407e7331 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -13,23 +13,34 @@ # limitations under the License. 
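The 'megatron' branch added to get_nmt_tokenizer above resolves Megatron-LM tokenizer type names to NeMo pretrained model names before delegating to get_tokenizer. A standalone sketch of that resolution step, with the mapping copied from the megatron_tokenizer_model_map dict introduced above:

# Resolution step performed in get_nmt_tokenizer's 'megatron' branch (sketch only).
megatron_tokenizer_model_map = {
    'BertWordPieceLowerCase': 'megatron-bert-345m-uncased',
    'BertWordPieceCase': 'megatron-bert-345m-cased',
    'GPT2BPETokenizer': 'megatron-gpt-345m',
}

def resolve_megatron_tokenizer_name(model_name: str) -> str:
    # Megatron-LM style names (e.g. 'GPT2BPETokenizer') map to NeMo pretrained
    # model names; anything else is passed through unchanged.
    return megatron_tokenizer_model_map.get(model_name, model_name)

assert resolve_megatron_tokenizer_name('GPT2BPETokenizer') == 'megatron-gpt-345m'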
import os +import shutil +import tempfile from typing import Any, Dict, List, Optional, Union +import pytorch_lightning as pl import torch -from megatron import mpu -from megatron.checkpointing import get_checkpoint_version, set_checkpoint_version -from megatron.initialize import _set_random_seed -from pytorch_lightning.core.lightning import LightningModule +from apex.transformer import parallel_state from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector +from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.enums import GradClipAlgorithmType +from torch.nn.modules.module import Module from torch.nn.parallel import DistributedDataParallel +from torch.optim.optimizer import Optimizer -from nemo.collections.nlp.modules.common.megatron.megatron_bert import MegatronBertEncoder -from nemo.collections.nlp.modules.common.megatron.megatron_encoder import MegatronEncoderModule +from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32 +from nemo.collections.nlp.modules.common.megatron.megatron_bert import ( + get_megatron_checkpoint_version, + set_megatron_checkpoint_version, +) +from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.utils import AppState, logging @@ -45,21 +56,22 @@ class NLPDDPPlugin(DDPPlugin): num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, sync_batchnorm: bool = False, + checkpoint_io: Optional[CheckpointIO] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: - super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm, **kwargs) + super().__init__(parallel_devices, num_nodes, cluster_environment, checkpoint_io, sync_batchnorm, **kwargs) - def init_ddp_connection(self, global_rank: int = None, world_size: int = None) -> None: + def setup_distributed(self, global_rank: int = None, world_size: int = None) -> None: # call PTL init ddp - super().init_ddp_connection() + super().setup_distributed() # init model parallel if needed app_state = AppState() if app_state.model_parallel_size is not None: - - if self.lightning_module.has_megatron_encoder and not self.lightning_module.is_model_parallel_initialized: - self.init_model_parallel(app_state.global_rank, app_state.world_size) + self.init_model_parallel(app_state.global_rank, app_state.world_size) + # if self.lightning_module.has_megatron_encoder and not self.lightning_module.is_model_parallel_initialized: + # self.init_model_parallel(app_state.global_rank, app_state.world_size) def start_training(self, trainer: 'Trainer') -> None: """ PTL Hook that is called after DPP is initialized. 
""" @@ -73,7 +85,7 @@ class NLPDDPPlugin(DDPPlugin): if not hasattr(p, 'model_parallel'): p.model_parallel = False - if get_checkpoint_version() is not None: + if get_megatron_checkpoint_version() is not None: # megatron checkpoint already restored pass elif trainer.checkpoint_connector.resume_checkpoint_path is not None: @@ -92,14 +104,14 @@ class NLPDDPPlugin(DDPPlugin): 'checkpoint_version', None ) if checkpoint_version is not None: - set_checkpoint_version(checkpoint_version) + set_megatron_checkpoint_version(checkpoint_version) else: logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.') - set_checkpoint_version(0) + set_megatron_checkpoint_version(0) else: self.lightning_module.restore_megatron_encoder_weights() else: - if get_checkpoint_version() is not None: + if get_megatron_checkpoint_version() is not None: # megatron checkpoint already restored pass else: @@ -117,7 +129,7 @@ class NLPDDPPlugin(DDPPlugin): if self.has_megatron_encoder: # check megatron checkpoint version - checkpoint_version = get_checkpoint_version() + checkpoint_version = get_megatron_checkpoint_version() if checkpoint_version is None: raise ValueError("Unable to find megatron checkpoint version.") @@ -125,7 +137,7 @@ class NLPDDPPlugin(DDPPlugin): def configure_ddp(self): """ Override LightningModule ddp if using model parallel. - Sets find_unused_parameters to True. + Sets find_unused_parameters to False to use activation-checkpoint-recomputation. """ app_state = AppState() @@ -142,7 +154,7 @@ class NLPDDPPlugin(DDPPlugin): device_ids=device_ids, output_device=device_ids[0], process_group=app_state.data_parallel_group, - find_unused_parameters=True, + find_unused_parameters=False, **self._ddp_kwargs, ) @@ -163,17 +175,35 @@ class NLPDDPPlugin(DDPPlugin): # after initializing DDP with PTL. if app_state.model_parallel_size is not None: if torch.distributed.is_initialized(): - mpu.initialize_model_parallel(app_state.model_parallel_size) - app_state.model_parallel_group = mpu.get_model_parallel_group() - app_state.data_parallel_group = mpu.get_data_parallel_group() - app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank() - app_state.data_parallel_rank = mpu.get_data_parallel_rank() - app_state.data_parallel_size = mpu.get_data_parallel_world_size() + parallel_state.initialize_model_parallel(app_state.model_parallel_size) + app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group() + app_state.data_parallel_group = parallel_state.get_data_parallel_group() + app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() + app_state.data_parallel_rank = parallel_state.get_data_parallel_rank() + app_state.data_parallel_size = parallel_state.get_data_parallel_world_size() logging.info(f'mp_rank: {app_state.model_parallel_rank}') logging.info(f'dp_rank: {app_state.data_parallel_rank}') - seed = os.environ.get("PL_GLOBAL_SEED", 1234) - # random seed must be set for megatron model parallel init - _set_random_seed(seed) + + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ + app_state = AppState() + # dump states as a checkpoint dictionary object + # TrainingTypePlugin.on_save() just seems to return the same thing. 
+ # checkpoint = self.on_save(checkpoint) + if self.is_global_zero or app_state.data_parallel_rank == 0: + try: + # write the checkpoint dictionary on the file + atomic_save(checkpoint, filepath) + except AttributeError as err: + key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY + checkpoint.pop(key, None) + rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") + atomic_save(checkpoint, filepath) @property def distributed_sampler_kwargs(self): @@ -195,18 +225,17 @@ class NLPCheckpointConnector(CheckpointConnector): """ Override PTL CheckpointConnector to support model parallel checkpoints from Megatron-LM. """ - def __init__(self, trainer): - super().__init__(trainer) + def __init__(self, trainer, resume_from_checkpoint): + super().__init__(trainer, resume_from_checkpoint) - def save_checkpoint(self, filepath, weights_only: bool): + def save_checkpoint(self, filepath, weights_only: bool = False) -> None: """Slightly modified version of PyTorch Lightning's save_checkpoint. + Accounts for model parallel training. + Save model/training states as a checkpoint file through state-dump and file-write. Args: - filepath ([str]): [description] - weights_only (bool): [description] - - Returns: - [type]: [description] + filepath: write-target file's path + weights_only: saving model weights only """ app_state = AppState() if app_state.model_parallel_size is not None: @@ -214,22 +243,137 @@ class NLPCheckpointConnector(CheckpointConnector): dirname = os.path.dirname(filepath) basename = os.path.basename(filepath) filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}' - - # dump states as a checkpoint dictionary object - checkpoint = self.dump_checkpoint(weights_only) - + _checkpoint = self.dump_checkpoint(weights_only) # each model parallel rank needs to save a copy of its model if app_state.data_parallel_rank == 0: - # write the checkpoint dictionary on the file - if self.trainer.accelerator_backend: - checkpoint = self.trainer.accelerator_backend.on_save(checkpoint) - try: - atomic_save(checkpoint, filepath) - except AttributeError as err: - if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: - del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] - rank_zero_warn( - 'Warning, `hyper_parameters` dropped from checkpoint.' 
f' An attribute is not picklable {err}' - ) - atomic_save(checkpoint, filepath) - return None + self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) + else: + super().save_checkpoint(filepath, weights_only) + + +class NLPSaveRestoreConnector(SaveRestoreConnector): + def __init__(self) -> None: + super().__init__() + + def save_to(self, model, save_path: str): + app_state = AppState() + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + + dir_name = os.path.dirname(save_path) + + # first we save the weights for each model parallel rank + if app_state.data_parallel_rank == 0: + mp_model_weights = os.path.join( + dir_name, f'mp_rank_{app_state.model_parallel_rank:02d}_' + self.model_weights_ckpt + ) + self._save_state_dict_to_disk(model.state_dict(), mp_model_weights) + + torch.distributed.barrier() + + # create nemo file from folder with all mp_ranks checkpoints + if app_state.model_parallel_rank == 0 and app_state.data_parallel_rank == 0: + with tempfile.TemporaryDirectory() as tmpdir: + + # move weights to the tmpdir + for mp_rank in range(app_state.model_parallel_size): + os.makedirs(os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}')) + mp_model_weights = os.path.join(dir_name, f'mp_rank_{mp_rank:02d}_' + self.model_weights_ckpt) + shutil.move( + mp_model_weights, os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}', self.model_weights_ckpt) + ) + + # create config and artifacts in tmpdir + config_yaml = os.path.join(tmpdir, self.model_config_yaml) + model.to_config_file(path2yaml_file=config_yaml) + if hasattr(model, 'artifacts') and model.artifacts is not None: + self._handle_artifacts(model, nemo_file_folder=tmpdir) + self._update_artifact_paths(model, path2yaml_file=config_yaml) + + # create tar file + self._make_nemo_file_from_folder(save_path, tmpdir) + + else: + return super().save_to(model, save_path) + + +class NLPNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None: + super().__init__(precision=16) + + self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval) + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType, + model: Optional[Module], + ) -> None: + """Override PTL gradient clipping. + Do nothing because we've already clipped gradients in `on_before_optimizer_step` hook. + """ + pass + + +class NLPNativeBfloat16PrecisionPlugin(NativeMixedPrecisionPlugin): + def __init__(self) -> None: + super().__init__(precision='bf16') + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType, + model: Optional[Module], + ) -> None: + """Override PTL gradient clipping. + Model parallel models require gradient clipping from megatron-lm. 
+ """ + + if clip_val is None: + return + + clip_val = float(clip_val) + if clip_val <= 0: + return + + app_state = AppState() + if app_state.model_parallel_size is not None: + parameters = model.parameters() + clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val) + else: + return super().clip_gradients( + optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model + ) + + +class NLPPrecisionPlugin(PrecisionPlugin): + def __init__(self) -> None: + super().__init__() + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType, + model: Optional[Module], + ) -> None: + """Override PTL gradient clipping. + Model parallel models require gradient clipping from megatron-lm. + """ + + if clip_val is None: + return + + clip_val = float(clip_val) + if clip_val <= 0: + return + + app_state = AppState() + if app_state.model_parallel_size is not None: + parameters = model.parameters() + clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val) + else: + return super().clip_gradients( + optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model + ) diff --git a/nemo/collections/tts/models/fastspeech2.py b/nemo/collections/tts/models/fastspeech2.py index d93d3b256..926773f43 100644 --- a/nemo/collections/tts/models/fastspeech2.py +++ b/nemo/collections/tts/models/fastspeech2.py @@ -150,7 +150,7 @@ class FastSpeech2Model(SpectrogramGenerator): self.log(name="train_loss", value=total_loss) return {"loss": total_loss, "outputs": [spec, mel]} - def training_epoch_end(self, training_step_outputs): + def training_epoch_end(self, outputs): if self.log_train_images and self.logger is not None and self.logger.experiment is not None: tb_logger = self.logger.experiment if isinstance(self.logger, LoggerCollection): @@ -158,7 +158,7 @@ class FastSpeech2Model(SpectrogramGenerator): if isinstance(logger, TensorBoardLogger): tb_logger = logger.experiment break - spec_target, spec_predict = training_step_outputs[0]["outputs"] + spec_target, spec_predict = outputs[0]["outputs"] tb_logger.add_image( "train_mel_target", plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), @@ -171,6 +171,8 @@ class FastSpeech2Model(SpectrogramGenerator): ) self.log_train_images = False + return super().training_epoch_end(outputs) + def validation_step(self, batch, batch_idx): f, fl, t, tl, _, _, _ = batch spec, spec_len = self.audio_to_melspec_preprocessor(f, fl) diff --git a/nemo/collections/tts/models/hifigan.py b/nemo/collections/tts/models/hifigan.py index c24ca61e2..7a8c5990d 100644 --- a/nemo/collections/tts/models/hifigan.py +++ b/nemo/collections/tts/models/hifigan.py @@ -126,7 +126,7 @@ class HifiGanModel(Vocoder, Exportable): def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor': return self(spec=spec).squeeze(1) - def training_step(self, batch, batch_idx, optimizer_idx): + def training_step(self, batch, batch_idx): # if in finetune mode the mels are pre-computed using a # spectrogram generator if self.input_as_mel: diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 46d710f03..6bd6058f7 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -15,6 +15,7 @@ """Interfaces common to all Neural Modules and Models.""" import hashlib +import inspect import traceback from abc import ABC, abstractmethod from contextlib import contextmanager @@ -429,7 +430,7 @@ class Typing(ABC): class Serialization(ABC): @classmethod - def 
from_config_dict(cls, config: 'DictConfig'): + def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = None): """Instantiates object using DictConfig-based configuration""" # Resolve the config dict if _HAS_HYDRA: @@ -470,7 +471,12 @@ class Serialization(ABC): imported_cls = cls try: - instance = imported_cls(cfg=config) + accepts_trainer = Serialization._inspect_signature_for_trainer(imported_cls) + if accepts_trainer: + instance = imported_cls(cfg=config, trainer=trainer) + else: + instance = imported_cls(cfg=config) + except Exception as e: imported_cls_tb = traceback.format_exc() instance_init_error = str(e) @@ -484,8 +490,14 @@ class Serialization(ABC): f"Falling back to `cls`.\n" f"{imported_cls_tb}" ) + instance = cls(cfg=config, trainer=trainer) try: - instance = cls(cfg=config) + accepts_trainer = Serialization._inspect_signature_for_trainer(cls) + if accepts_trainer: + instance = cls(cfg=config, trainer=trainer) + else: + instance = cls(cfg=config) + except Exception as e: if imported_cls_tb is not None: logging.error(f"Instance failed restore_from due to: {instance_init_error}") @@ -515,6 +527,17 @@ class Serialization(ABC): 'to_config_dict() can currently only return object._cfg but current object does not have it.' ) + @classmethod + def _inspect_signature_for_trainer(cls, check_cls): + if hasattr(check_cls, '__init__'): + signature = inspect.signature(check_cls.__init__) + if 'trainer' in signature.parameters: + return True + else: + return False + else: + return False + + class FileIO(ABC): def save_to(self, save_path: str): @@ -529,6 +552,8 @@ class FileIO(ABC): map_location: Optional['torch.device'] = None, strict: bool = True, return_config: bool = False, + trainer: Optional['Trainer'] = None, + save_restore_connector: SaveRestoreConnector = None, ): """Restores module/model with weights""" raise NotImplementedError() @@ -644,6 +669,7 @@ class Model(Typing, Serialization, FileIO): map_location: Optional['torch.device'] = None, strict: bool = True, return_config: bool = False, + trainer: Optional['Trainer'] = None, save_restore_connector: SaveRestoreConnector = None, ): """ @@ -708,6 +734,7 @@ class Model(Typing, Serialization, FileIO): map_location=map_location, strict=strict, return_config=return_config, + trainer=trainer, save_restore_connector=save_restore_connector, ) return instance diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 0458d9277..9caef4a62 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -201,7 +201,6 @@ class ModelPT(LightningModule, Model): return self._save_restore_connector.register_artifact(self, config_path, src, verify_src_exists) - @rank_zero_only def save_to(self, save_path: str): """ Saves model instance (weights and configuration) into .nemo file @@ -215,12 +214,24 @@ class ModelPT(LightningModule, Model): save_path: Path to .nemo file where model instance should be saved """ + app_state = AppState() # Add NeMo rank check as well - if not is_global_rank_zero(): - return - else: + if app_state.model_parallel_size is not None: + if app_state.model_parallel_size > 1: + if isinstance(self._save_restore_connector, SaveRestoreConnector): + raise ValueError( + 'Default NeMo SaveRestoreConnector will not work in model parallel mode. You should use a connector which supports model parallel mode, such as NLPSaveRestoreConnector in NLP. You can also use a custom one.'
+ ) + save_path = os.path.abspath(os.path.expanduser(save_path)) + # connector checks for ranks properly, no need to check here self._save_restore_connector.save_to(self, save_path) + else: + if not is_global_rank_zero(): + return + else: + save_path = os.path.abspath(os.path.expanduser(save_path)) + self._save_restore_connector.save_to(self, save_path) @classmethod def restore_from( @@ -231,6 +242,7 @@ class ModelPT(LightningModule, Model): strict: bool = True, return_config: bool = False, save_restore_connector: SaveRestoreConnector = None, + trainer: Optional[Trainer] = None, ): """ Restores model instance (weights and configuration) from .nemo file. @@ -244,6 +256,8 @@ class ModelPT(LightningModule, Model): strict: Passed to load_state_dict. By default True. return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model. + trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the + instantiated model's constructor. save_restore_connector (SaveRestoreConnector): Can be overrided to add custom save and restore logic. Example: @@ -268,7 +282,7 @@ class ModelPT(LightningModule, Model): cls.update_save_restore_connector(save_restore_connector) instance = cls._save_restore_connector.restore_from( - cls, restore_path, override_config_path, map_location, strict, return_config + cls, restore_path, override_config_path, map_location, strict, return_config, trainer ) if isinstance(instance, ModelPT): instance._save_restore_connector = save_restore_connector diff --git a/nemo/core/config/pytorch_lightning.py b/nemo/core/config/pytorch_lightning.py index 51067bb9f..646c697d4 100644 --- a/nemo/core/config/pytorch_lightning.py +++ b/nemo/core/config/pytorch_lightning.py @@ -48,6 +48,7 @@ class TrainerConfig: tpu_cores: Optional[Any] = None log_gpu_memory: Optional[str] = None progress_bar_refresh_rate: int = 1 + enable_progress_bar: bool = True overfit_batches: Any = 0.0 track_grad_norm: Any = -1 check_val_every_n_epoch: int = 1 @@ -65,11 +66,10 @@ class TrainerConfig: log_every_n_steps: int = 50 accelerator: Optional[str] = None sync_batchnorm: bool = False - precision: int = 32 + precision: Any = 32 weights_summary: Optional[str] = "full" # ModelSummary.MODE_DEFAULT weights_save_path: Optional[str] = None num_sanity_val_steps: int = 2 - truncated_bptt_steps: Optional[int] = None resume_from_checkpoint: Optional[str] = None profiler: Optional[Any] = None benchmark: bool = False @@ -77,11 +77,12 @@ class TrainerConfig: reload_dataloaders_every_epoch: bool = False auto_lr_find: Any = False replace_sampler_ddp: bool = True + detect_anomaly: bool = False terminate_on_nan: bool = False auto_scale_batch_size: Any = False prepare_data_per_node: bool = True amp_backend: str = 'native' - amp_level: str = 'O2' # backward compatible, todo: remove in v1.0.0 + amp_level: Optional[str] = None plugins: Optional[Any] = None # Optional[Union[str, list]] move_metrics_to_cpu: bool = False multiple_trainloader_mode: str = 'max_size_cycle' diff --git a/nemo/core/config/schedulers.py b/nemo/core/config/schedulers.py index 846b785d6..effabb660 100644 --- a/nemo/core/config/schedulers.py +++ b/nemo/core/config/schedulers.py @@ -51,6 +51,18 @@ class WarmupHoldSchedulerParams(WarmupSchedulerParams): min_lr: float = 0.0 +@dataclass +class WarmupAnnealingHoldSchedulerParams(WarmupSchedulerParams): + """ + Base configuration for all schedulers. 
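Adds an optional hold phase on top of the warmup parameters: constant_steps / constant_ratio control how long the learning rate is held at min_lr after annealing, and min_lr sets the decay floor.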
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + constant_steps: Optional[float] = None + constant_ratio: Optional[float] = None + min_lr: float = 0.0 + + @dataclass class SquareAnnealingParams(WarmupSchedulerParams): """ @@ -72,7 +84,7 @@ class SquareRootAnnealingParams(WarmupSchedulerParams): @dataclass -class CosineAnnealingParams(WarmupSchedulerParams): +class CosineAnnealingParams(WarmupAnnealingHoldSchedulerParams): """ Cosine Annealing parameter config It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). @@ -98,7 +110,7 @@ class WarmupAnnealingParams(WarmupSchedulerParams): It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). """ - warmup_ratio: 0.0 + warmup_ratio: float = 0.0 @dataclass @@ -240,6 +252,7 @@ AVAILABLE_SCHEDULER_PARAMS = { 'SchedulerParams': SchedulerParams, 'WarmupPolicyParams': WarmupSchedulerParams, 'WarmupHoldPolicyParams': WarmupHoldSchedulerParams, + 'WarmupAnnealingHoldSchedulerParams': WarmupAnnealingHoldSchedulerParams, 'SquareAnnealingParams': SquareAnnealingParams, 'SquareRootAnnealingParams': SquareRootAnnealingParams, 'InverseSquareRootAnnealingParams': InverseSquareRootAnnealingParams, diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index cf7314b9d..598f9dc52 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -23,9 +23,11 @@ from typing import Optional, Union import torch from omegaconf import DictConfig, OmegaConf from omegaconf.omegaconf import open_dict +from pytorch_lightning.trainer.trainer import Trainer from nemo.utils import logging, model_utils from nemo.utils.app_state import AppState +from nemo.utils.get_rank import is_global_rank_zero class SaveRestoreConnector: @@ -47,16 +49,19 @@ class SaveRestoreConnector: save_path: Path to .nemo file where model instance should be saved """ - with tempfile.TemporaryDirectory() as tmpdir: - config_yaml = os.path.join(tmpdir, self.model_config_yaml) - model_weights = os.path.join(tmpdir, self.model_weights_ckpt) - model.to_config_file(path2yaml_file=config_yaml) - if hasattr(model, 'artifacts') and model.artifacts is not None: - self._handle_artifacts(model, nemo_file_folder=tmpdir) - # We should not update self._cfg here - the model can still be in use - self._update_artifact_paths(model, path2yaml_file=config_yaml) - self._save_state_dict_to_disk(model.state_dict(), model_weights) - self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir) + if is_global_rank_zero(): + with tempfile.TemporaryDirectory() as tmpdir: + config_yaml = os.path.join(tmpdir, self.model_config_yaml) + model_weights = os.path.join(tmpdir, self.model_weights_ckpt) + model.to_config_file(path2yaml_file=config_yaml) + if hasattr(model, 'artifacts') and model.artifacts is not None: + self._handle_artifacts(model, nemo_file_folder=tmpdir) + # We should not update self._cfg here - the model can still be in use + self._update_artifact_paths(model, path2yaml_file=config_yaml) + self._save_state_dict_to_disk(model.state_dict(), model_weights) + self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir) + else: + return def restore_from( self, @@ -66,6 +71,7 @@ class SaveRestoreConnector: map_location: Optional[torch.device] = None, strict: bool = True, return_config: bool = False, + trainer: Trainer = None, ): """ Restores model 
instance (weights and configuration) into .nemo file @@ -125,7 +131,7 @@ class SaveRestoreConnector: return instance else: app_state = AppState() - if app_state.model_parallel_rank is not None: + if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1: model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt) else: model_weights = os.path.join(tmpdir, self.model_weights_ckpt) @@ -133,7 +139,7 @@ class SaveRestoreConnector: os.chdir(cwd) # get the class calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir) - instance = calling_cls.from_config_dict(config=conf) + instance = calling_cls.from_config_dict(config=conf, trainer=trainer) instance = instance.to(map_location) # add load_state_dict override instance.load_state_dict( @@ -369,6 +375,8 @@ class SaveRestoreConnector: @staticmethod def _make_nemo_file_from_folder(filename, source_dir): + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) with tarfile.open(filename, "w:gz") as tar: tar.add(source_dir, arcname=".") diff --git a/nemo/core/optim/lr_scheduler.py b/nemo/core/optim/lr_scheduler.py index a84be13f6..e6bc82d2e 100644 --- a/nemo/core/optim/lr_scheduler.py +++ b/nemo/core/optim/lr_scheduler.py @@ -160,6 +160,84 @@ class WarmupHoldPolicy(WarmupPolicy): return self._get_lr(step) +class WarmupAnnealHoldPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. + constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
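# Illustrative note with hypothetical numbers (not part of the original patch):
# with max_steps=1000, warmup_steps=100 and constant_steps=200, the assignments
# below give decay_steps = 1000 - (200 + 100) = 700, i.e. the schedule warms up
# for 100 steps, anneals for 700 steps, then holds at min_lr for the last 200.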
+ self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps) + assert self.decay_steps > 0 + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps) + assert self.decay_steps > 0 + else: + self.constant_steps = 0 + self.decay_steps = max_steps - self.warmup_steps + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + if step <= self.warmup_steps: + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + def _squareroot_annealing(initial_lr, step, max_steps, min_lr): mult = ((max_steps - step) / max_steps) ** 0.5 out_lr = initial_lr * mult @@ -180,6 +258,30 @@ def _cosine_annealing(initial_lr, step, max_steps, min_lr): return out_lr +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr): + + assert max_lr > min_lr + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): if cycle: multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) @@ -221,7 +323,7 @@ class SquareRootAnnealing(WarmupPolicy): return new_lrs -class CosineAnnealing(WarmupPolicy): +class CosineAnnealing(WarmupAnnealHoldPolicy): def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) @@ -232,15 +334,27 @@ class CosineAnnealing(WarmupPolicy): f"{self} received an initial learning rate that was lower than the minimum learning rate." 
) - new_lrs = [ - _cosine_annealing( - initial_lr=initial_lr, - step=step - self.warmup_steps, - max_steps=self.max_steps - self.warmup_steps, - min_lr=self.min_lr, - ) - for initial_lr in self.base_lrs - ] + if self.constant_steps is None: + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + else: + new_lrs = [ + _linear_warmup_with_cosine_annealing( + max_lr=self.base_lrs[0], + warmup_steps=self.warmup_steps, + step=step, + decay_steps=self.decay_steps, + min_lr=self.min_lr, + ) + for _ in self.base_lrs + ] return new_lrs @@ -587,7 +701,17 @@ def prepare_lr_scheduler( # Compute effective num max_steps num_samples = len(train_dataloader.dataset) - batch_size = train_dataloader.batch_size + # TODO: not sure if this will be the correct LR schedule for Megatron + # we may need to override ModelPT setup_optimization + if train_dataloader.batch_size is not None: + batch_size = train_dataloader.batch_size + elif train_dataloader.batch_sampler is not None: + if train_dataloader.batch_sampler.micro_batch_size is not None: + batch_size = train_dataloader.batch_sampler.micro_batch_size + else: + raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}') + else: + raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}') drop_last = train_dataloader.drop_last max_steps = compute_max_steps( diff --git a/nemo/core/optim/optimizers.py b/nemo/core/optim/optimizers.py index 127e44bee..9d818be58 100644 --- a/nemo/core/optim/optimizers.py +++ b/nemo/core/optim/optimizers.py @@ -17,6 +17,7 @@ from functools import partial from typing import Any, Dict, List, Optional, Union import hydra +import torch import torch.optim as optim from omegaconf import DictConfig, OmegaConf from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop @@ -41,10 +42,12 @@ AVAILABLE_OPTIMIZERS = { try: from apex.optimizers import FusedLAMB + from apex.optimizers import FusedAdam AVAILABLE_OPTIMIZERS['lamb'] = FusedLAMB + AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam except ModuleNotFoundError: - logging.warning("Apex was not found. Using the lamb optimizer will error out.") + logging.warning("Apex was not found. Using the lamb or fused_adam optimizer will error out.") __all__ = ['get_optimizer', 'register_optimizer', 'parse_optimizer_args'] @@ -165,6 +168,9 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer: raise ValueError( f"Cannot resolve optimizer '{name}'. 
Available optimizers are : " f"{AVAILABLE_OPTIMIZERS.keys()}" ) + if name == 'fused_adam': + if not torch.cuda.is_available(): + raise ValueError(f'CUDA must be available to use fused_adam.') optimizer = AVAILABLE_OPTIMIZERS[name] optimizer = partial(optimizer, **kwargs) diff --git a/nemo/package_info.py b/nemo/package_info.py index 5112f221b..87e54c55c 100644 --- a/nemo/package_info.py +++ b/nemo/package_info.py @@ -14,9 +14,9 @@ MAJOR = 1 -MINOR = 4 +MINOR = 5 PATCH = 0 -PRE_RELEASE = '' +PRE_RELEASE = 'b1' # Use the following formatting: (major, minor, patch, pre-release) VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index eee9a4bd3..22d6717ac 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -44,8 +44,10 @@ class AppState(metaclass=Singleton): self._world_size = None self._model_parallel_size = None self._model_parallel_group = None + self._is_megatron_initialized = False self._data_parallel_size = None self._data_parallel_group = None + self._megatron_checkpoint_version = None self._random_seed = None diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 776d2be3a..a5a6daa50 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -31,11 +31,10 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.types import _METRIC from nemo.constants import NEMO_ENV_VARNAME_VERSION -from nemo.utils import app_state, logging, timers +from nemo.utils import logging, timers from nemo.utils.app_state import AppState from nemo.utils.exceptions import NeMoBaseException from nemo.utils.get_rank import is_global_rank_zero @@ -72,7 +71,6 @@ class CallbackParams: save_top_k: Optional[int] = 3 save_weights_only: Optional[bool] = False mode: Optional[str] = "min" - period: Optional[int] = None every_n_val_epochs: Optional[int] = 1 prefix: Optional[str] = None # If None, exp_manager will attempt to handle the filepath postfix: str = ".nemo" @@ -380,6 +378,7 @@ def check_resume( NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found. ValueError: If resume is True, and there were more than 1 checkpoint could found. 
""" + if not log_dir: raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager") @@ -707,58 +706,55 @@ class NeMoModelCheckpoint(ModelCheckpoint): for _ in range(models_to_delete): model = best_k_models.pop(-1) self.best_k_models.pop(model) - self._del_model(model) + self._del_model_without_trainer(model) logging.debug(f"Removed checkpoint: {model}") self.kth_best_model_path = best_k_models[-1] self.best_model_path = best_k_models[0] self.best_model_score = self.best_k_models[self.best_model_path] - @rank_zero_only def on_save_checkpoint(self, trainer, pl_module, checkpoint): output = super().on_save_checkpoint(trainer, pl_module, checkpoint) - if not self.always_save_nemo: return output - # Load the best model and then re-save it - app_state = AppState() - # since we are creating tarfile artifacts we need to update .nemo path - app_state.model_restore_path = os.path.abspath( - os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) - ) - if self.save_best_model: - if not os.path.exists(self.best_model_path): - return output - - if self.best_model_path == self.previous_best_path: - return output - - self.previous_model_path = self.best_model_path - old_state_dict = deepcopy(pl_module.state_dict()) - checkpoint = torch.load(self.best_model_path, map_location='cpu') - if 'state_dict' in checkpoint: - checkpoint = checkpoint['state_dict'] - # get a new instanace of the model - pl_module.load_state_dict(checkpoint, strict=True) - pl_module.save_to(save_path=app_state.model_restore_path) - pl_module.load_state_dict(old_state_dict, strict=True) else: - pl_module.save_to(save_path=app_state.model_restore_path) - return output + # Load the best model and then re-save it + app_state = AppState() + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + raise ValueError(f'always_save_nemo is not implemented for model parallel models.') + # since we are creating tarfile artifacts we need to update .nemo path + app_state.model_restore_path = os.path.abspath( + os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) + ) + if self.save_best_model: + if not os.path.exists(self.best_model_path): + return output + + if self.best_model_path == self.previous_best_path: + return output + + self.previous_model_path = self.best_model_path + old_state_dict = deepcopy(pl_module.state_dict()) + checkpoint = torch.load(self.best_model_path, map_location='cpu') + if 'state_dict' in checkpoint: + checkpoint = checkpoint['state_dict'] + # get a new instanace of the model + pl_module.load_state_dict(checkpoint, strict=True) + pl_module.save_to(save_path=app_state.model_restore_path) + pl_module.load_state_dict(old_state_dict, strict=True) + else: + pl_module.save_to(save_path=app_state.model_restore_path) + return output - @rank_zero_only def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: return None - app_state = AppState() - if app_state.model_parallel_size is not None: - return None - # TODO: make this work for model parallel, need to call on data parallel rank 0 and update best_model_path # Load the best model and then re-save it if self.save_best_model: trainer.checkpoint_connector.restore(self.best_model_path) + pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix)) def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None: @@ -768,18 +764,36 @@ class NeMoModelCheckpoint(ModelCheckpoint): app_state = AppState() if app_state.model_parallel_size is not None: # 
filepath needs to be updated to include mp_rank + # TODO: figure out a good way to update these filepaths dirname = os.path.dirname(filepath) basename = os.path.basename(filepath) filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}' # each model parallel rank needs to remove its model if app_state.data_parallel_rank == 0: - super()._del_model(trainer, filepath) - logging.info(f"Removed model parallel checkpoint: {filepath}") + if self._fs.exists(filepath): + self._fs.rm(filepath) + logging.info(f"Removed model parallel checkpoint: {filepath}") else: return super()._del_model(trainer, filepath) + def _del_model_without_trainer(self, filepath: str) -> None: + app_state = AppState() + if app_state.model_parallel_size is not None: + # filepath needs to be updated to include mp_rank + dirname = os.path.dirname(filepath) + basename = os.path.basename(filepath) + filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}' + + # each model parallel rank needs to remove its model + if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0): + try: + self._fs.rm(filepath) + logging.info(f"Removed checkpoint: {filepath}") + except: + logging.info(f"Tried to remove checkpoint: {filepath} but failed.") + def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, _METRIC]) -> None: """ Overrides PTL method to account for model parallel checkpoints. Checks for data parallel rank 0 rather than global rank 0. @@ -792,11 +806,18 @@ class NeMoModelCheckpoint(ModelCheckpoint): filepath = self._format_checkpoint_name(self.CHECKPOINT_NAME_LAST, monitor_candidates) filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}") - self._save_model(trainer, filepath) + trainer.save_checkpoint(filepath) + + # TODO: figure out where self.last_model_path is being set + if self.last_model_path is not None: + if 'mp_rank' in self.last_model_path: + last_model_path = Path(self.last_model_path) + last_model_path = last_model_path.parent.parent.joinpath(last_model_path.name) + self.last_model_path = str(last_model_path) # for model parallel we need to delete models for each model parallel rank if self.last_model_path and self.last_model_path != filepath and app_state.data_parallel_rank == 0: - self._del_model(self.last_model_path) + self._del_model(trainer, self.last_model_path) self.last_model_path = filepath @@ -813,7 +834,8 @@ class NeMoModelCheckpoint(ModelCheckpoint): return filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) - self._save_model(trainer, filepath) + + trainer.save_checkpoint(filepath) if ( self.save_top_k is None @@ -821,7 +843,7 @@ class NeMoModelCheckpoint(ModelCheckpoint): and self.best_model_path != filepath and app_state.data_parallel_rank == 0 ): - self._del_model(self.best_model_path) + self._del_model(trainer, self.best_model_path) self.best_model_path = filepath else: @@ -887,14 +909,6 @@ def configure_checkpointing( f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out." ) - if params.period is not None: - logging.warning( - "The use of `period` in the checkpoint callback is deprecrated, please use `every_n_val_epochs` instead. " - "Overwriting `every_n_val_epochs` with `period`." 
- ) - params.every_n_val_epochs = params.period - - # NOTE: checkpoint_callback should be called last checkpoint_callback = NeMoModelCheckpoint(n_resume=resume, **params) checkpoint_callback.last_model_path = trainer.checkpoint_connector.resume_checkpoint_path or "" trainer.callbacks.append(checkpoint_callback) diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index 8e721beb8..f8480abad 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -1,4 +1,4 @@ -pytorch-lightning>=1.4.8,<1.5 +pytorch-lightning>1.4.9 torchmetrics>=0.4.1rc0 transformers>=4.0.1 webdataset>=0.1.48,<=0.1.62 diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 49cde607b..dcc2c5b7d 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -6,7 +6,6 @@ youtokentome>=1.0.5 numpy rapidfuzz gdown -megatron-lm==2.2.0 inflect sacrebleu[ja] sacremoses>=0.0.43 diff --git a/scripts/nlp_language_modeling/preprocess_data_for_megatron.py b/scripts/nlp_language_modeling/preprocess_data_for_megatron.py new file mode 100644 index 000000000..472d1d478 --- /dev/null +++ b/scripts/nlp_language_modeling/preprocess_data_for_megatron.py @@ -0,0 +1,214 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
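For orientation, here is a minimal sketch of how the new preprocessing script below might be driven end to end. The corpus path, tokenizer type, and vocab/merges files are placeholder assumptions; the flags are the ones defined in get_args() further down, and the output naming follows main().

import json
import subprocess
import sys

# Write a tiny JSON-lines corpus; each line carries the keys named by --json-keys.
with open("train_corpus.json", "w", encoding="utf-8") as f:
    for text in ["NeMo 1.5 adds Megatron-style pretraining.", "Each document is one JSON line."]:
        f.write(json.dumps({"text": text}) + "\n")

# Tokenize it into the .bin/.idx pair consumed by the Megatron indexed datasets.
subprocess.run(
    [
        sys.executable,
        "scripts/nlp_language_modeling/preprocess_data_for_megatron.py",
        "--input", "train_corpus.json",
        "--json-keys", "text",
        "--tokenizer-library", "megatron",
        "--tokenizer-type", "GPT2BPETokenizer",
        "--vocab-file", "gpt2-vocab.json",
        "--merge-file", "gpt2-merges.txt",
        "--append-eod",
        "--dataset-impl", "mmap",
        "--output-prefix", "train_corpus",
        "--workers", "2",
    ],
    check=True,
)
# Expected outputs, per the naming in main(): train_corpus_text_document.bin and .idx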
+ +"""Processing data for megatron pretraining.""" + +import argparse +import json +import multiprocessing +import sys +import time + +import torch + +from nemo.collections.nlp.data.language_modeling.megatron import indexed_dataset +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + +try: + import nltk + + nltk_available = True +except ImportError: + nltk_available = False + +# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer +class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): + + _period_context_fmt = r""" + \S* # some word material + %(SentEndChars)s # a potential sentence ending + \s* # <-- THIS is what I changed + (?=(?P + %(NonWord)s # either other punctuation + | + (?P\S+) # <-- Normally you would have \s+ here + ))""" + + +class IdentitySplitter(object): + def tokenize(self, *text): + return text + + +class Encoder(object): + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = get_nmt_tokenizer( + library=self.args.tokenizer_library, + model_name=self.args.tokenizer_type, + tokenizer_model=self.args.tokenizer_model, + vocab_file=self.args.vocab_file, + merges_file=self.args.merge_file, + ) + if self.args.split_sentences: + if not nltk_available: + print("NLTK is not available to split sentences.") + exit() + splitter = nltk.load("tokenizers/punkt/english.pickle") + if self.args.keep_newlines: + # this prevents punkt from eating newlines after sentences + Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer( + train_text=splitter._params, lang_vars=CustomLanguageVars() + ) + else: + Encoder.splitter = splitter + + else: + Encoder.splitter = IdentitySplitter() + + def encode(self, json_line): + data = json.loads(json_line) + ids = {} + for key in self.args.json_keys: + text = data[key] + doc_ids = [] + for sentence in Encoder.splitter.tokenize(text): + sentence_ids = Encoder.tokenizer.text_to_ids(sentence) + if len(sentence_ids) > 0: + doc_ids.append(sentence_ids) + if len(doc_ids) > 0 and self.args.append_eod: + doc_ids[-1].append(Encoder.tokenizer.eos_id) + ids[key] = doc_ids + return ids, len(json_line) + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title='input data') + group.add_argument('--input', type=str, required=True, help='Path to input JSON') + group.add_argument( + '--json-keys', nargs='+', default=['text'], help='space separate listed of keys to extract from json' + ) + group.add_argument('--split-sentences', action='store_true', help='Split documents into sentences.') + group.add_argument('--keep-newlines', action='store_true', help='Keep newlines between sentences when splitting.') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument( + '--tokenizer-library', + type=str, + required=True, + choices=['yttm', 'sentencepiece', 'megatron', 'huggingface'], + help='What tokenizer library to use.', + ) + group.add_argument( + '--tokenizer-type', type=str, default=None, help='What type of tokenizer to use.', + ) + group.add_argument( + '--tokenizer-model', type=str, default=None, help='Path to tokenizer model.', + ) + group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file') + group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file (if necessary).') + group.add_argument('--append-eod', action='store_true', help='Append an token to the end of a 
document.') + + group = parser.add_argument_group(title='output data') + group.add_argument('--output-prefix', type=str, required=True, help='Path to binary output file without suffix') + group.add_argument('--dataset-impl', type=str, default='mmap', choices=['lazy', 'cached', 'mmap']) + + group = parser.add_argument_group(title='runtime') + group.add_argument('--workers', type=int, default=1, help='Number of worker processes to launch') + group.add_argument('--log-interval', type=int, default=100, help='Interval between progress updates') + args = parser.parse_args() + args.keep_empty = False + + if args.tokenizer_type is not None and args.tokenizer_type.lower().startswith('bert'): + if not args.split_sentences: + print("Bert tokenizer detected, are you sure you don't want to split sentences?") + + # some default/dummy values for the tokenizer + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + args.vocab_extra_ids = 0 + # TODO: There are dependencies b/w libraries and model files / tokenizer type strings to check. + assert args.tokenizer_type is not None or args.tokenizer_model is not None + return args + + +def main(): + args = get_args() + startup_start = time.time() + + print("Opening", args.input) + fin = open(args.input, 'r', encoding='utf-8') + + if nltk_available and args.split_sentences: + nltk.download("punkt", quiet=True) + + encoder = Encoder(args) + + tokenizer = get_nmt_tokenizer( + library=args.tokenizer_library, + model_name=args.tokenizer_type, + tokenizer_model=args.tokenizer_model, + vocab_file=args.vocab_file, + merges_file=args.merge_file, + ) + pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, fin, 25) + # encoded_docs = map(encoder.encode, fin) + + level = "document" + if args.split_sentences: + level = "sentence" + + print(f"Vocab size: {tokenizer.vocab_size}") + print(f"Output prefix: {args.output_prefix}") + output_bin_files = {} + output_idx_files = {} + builders = {} + for key in args.json_keys: + output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, key, level) + builders[key] = indexed_dataset.make_builder( + output_bin_files[key], impl=args.dataset_impl, vocab_size=tokenizer.vocab_size + ) + + startup_end = time.time() + proc_start = time.time() + total_bytes_processed = 0 + print("Time to startup:", startup_end - startup_start) + + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + for key, sentences in doc.items(): + if len(sentences) == 0: + continue + for sentence in sentences: + builders[key].add_item(torch.IntTensor(sentence)) + builders[key].end_document() + if i % args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + print(f"Processed {i} documents", f"({i/elapsed} docs/s, {mbs} MB/s).", file=sys.stderr) + + for key in args.json_keys: + builders[key].finalize(output_idx_files[key]) + + +if __name__ == '__main__': + main() diff --git a/tests/collections/nlp/test_megatron.py b/tests/collections/nlp/test_megatron.py index 4e7d448a7..59acc0a19 100644 --- a/tests/collections/nlp/test_megatron.py +++ b/tests/collections/nlp/test_megatron.py @@ -39,6 +39,7 @@ def get_pretrained_bert_345m_uncased_model(): class TestMegatron: + @pytest.mark.skip("This test was written for megatron-lm") 
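# Context: Megatron-LM based BERT support is deprecated in NeMo 1.5 (see the skip
# reason on test_onnx_export below), so these GPU-only Megatron tests are skipped.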
@pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_list_pretrained_models(self): @@ -62,6 +63,7 @@ class TestMegatron: @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit + @pytest.mark.skip("Megatron-LM BERT support deprecated. Supported in NeMo < 1.5") def test_onnx_export(self): model = get_pretrained_bert_345m_uncased_model() assert model diff --git a/tests/core/test_exp_manager.py b/tests/core/test_exp_manager.py index 4284622ed..6c4a2da07 100644 --- a/tests/core/test_exp_manager.py +++ b/tests/core/test_exp_manager.py @@ -36,7 +36,12 @@ class MyTestOptimizer(torch.optim.Optimizer): self._step = 0 super().__init__(params, {}) - def step(self, *args, **kwargs): + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if self._step == 0: @@ -46,7 +51,7 @@ class MyTestOptimizer(torch.optim.Optimizer): else: p.data = 0.01 * torch.ones(p.shape) self._step += 1 - return None + return loss class OnesDataset(torch.utils.data.Dataset): @@ -65,6 +70,7 @@ class ExampleModel(ModelPT): def __init__(self, *args, **kwargs): cfg = OmegaConf.structured({}) super().__init__(cfg) + pl.seed_everything(1234) self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1) def train_dataloader(self): diff --git a/tests/core/test_optimizers_schedulers.py b/tests/core/test_optimizers_schedulers.py index 84f7d14a7..e9e2d0011 100644 --- a/tests/core/test_optimizers_schedulers.py +++ b/tests/core/test_optimizers_schedulers.py @@ -111,11 +111,15 @@ class TestOptimizersSchedulers: MIN_LR = 1e-3 MAX_STEPS = 10 + # fused_adam is looking for CUDA and this test is being run on CPU only tests @pytest.mark.unit def test_get_optimizer(self): model = TempModel() for opt_name in AVAILABLE_OPTIMIZERS.keys(): + if opt_name == 'fused_adam': + if not torch.cuda.is_available(): + continue opt_cls = optim.get_optimizer(opt_name) opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)