[BigNLP] Merge Megatron GPT to main (#2975)

* fix gpu init after removing debug print in mpu

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_adam

Signed-off-by: ericharper <complex451@gmail.com>

* check ds is not none before logging len

Signed-off-by: ericharper <complex451@gmail.com>

* set fp16 arg to true and fix enum conflict

Signed-off-by: ericharper <complex451@gmail.com>

* make fp16 arg configurable

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* Linear warmup with cosine annealing and constant holding (#2846)

* Testing cosine schedule

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* More fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update config for constant steps in schedule

Signed-off-by: ericharper <complex451@gmail.com>
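
For context, a minimal sketch of a warmup-plus-cosine schedule with a constant hold (illustrative only, not the exact NeMo scheduler; this assumes the constant phase holds the minimum learning rate after annealing finishes):

    import math

    def warmup_cosine_hold(step, max_steps, warmup_steps, constant_steps,
                           max_lr, min_lr=0.0):
        """Linear warmup, cosine annealing, then hold at min_lr (assumed)."""
        decay_steps = max(1, max_steps - warmup_steps - constant_steps)
        if step < warmup_steps:
            return max_lr * step / max(1, warmup_steps)
        if step >= warmup_steps + decay_steps:
            return min_lr  # constant holding phase
        progress = (step - warmup_steps) / decay_steps
        return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

    # e.g. peak 2e-4 over 1000 steps with 100 warmup and 100 constant steps
    lrs = [warmup_cosine_hold(s, 1000, 100, 100, 2e-4, 2e-5) for s in range(1000)]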

* temporarily import enum from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip for fp32

Signed-off-by: ericharper <complex451@gmail.com>

* update check for _del_model_without_trainer

Signed-off-by: ericharper <complex451@gmail.com>

* updating restore for model parallel

Signed-off-by: ericharper <complex451@gmail.com>

* add predict script

Signed-off-by: ericharper <complex451@gmail.com>

* update test iters

Signed-off-by: ericharper <complex451@gmail.com>

* add barrier

Signed-off-by: ericharper <complex451@gmail.com>

* return if clip_val is 0 or None

Signed-off-by: ericharper <complex451@gmail.com>

* when using amp clip grads after they are unscaled

Signed-off-by: ericharper <complex451@gmail.com>
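
For context, a hedged sketch (not the NeMo code itself) of the ordering this enforces with native PyTorch AMP: skip clipping when clip_val is 0 or None, and unscale gradients before clipping so the threshold applies to the true gradient magnitudes:

    import torch

    def backward_clip_step(model, optimizer, scaler, loss, clip_val=None):
        """Backward, optional grad clipping, and optimizer step under native AMP.

        `scaler` is a torch.cuda.amp.GradScaler (its init_scale, growth_interval,
        etc. are the hyperparameters made configurable in the next commit).
        """
        scaler.scale(loss).backward()
        if clip_val:  # return early when clip_val is 0 or None
            scaler.unscale_(optimizer)  # unscale first, then clip
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
        scaler.step(optimizer)  # skipped internally if grads contain inf/NaN
        scaler.update()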

* make native amp scaler hyperparams configurable

Signed-off-by: ericharper <complex451@gmail.com>

* (1) nvfuser, (2) amp-casting decoration (#2894)

* (1) nvfuser, (2) amp-casting decoration

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* support bf16

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* update package info

Signed-off-by: ericharper <complex451@gmail.com>

* add set device to constructor

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Remove megatron-lm dependency. (#2910)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>
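
A small sketch of the kind of validation this adds (the flag names follow the commit message; the config object here is just a dict for illustration):

    def validate_fused_precision_flags(cfg):
        """Reject configs that enable both fused fp16 and fused bf16 paths."""
        if cfg.get("fused_fp16") and cfg.get("fused_bf16"):
            raise ValueError("fused_fp16 and fused_bf16 cannot both be True")

    validate_fused_precision_flags({"fused_fp16": True, "fused_bf16": False})  # ok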

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>

* NVfuser (#2943)

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* Megatron gpt bfloat support (#2926)

* Save/restore fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Another merge

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Bf16 args in init

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Set precision

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove debug stuff

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add bf16 casting decorator

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Bfloat layernorm propagation

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* More arg removal

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove BERTDataset

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update to latest apex and patch transformer autocast

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: ericharper <complex451@gmail.com>

* don't set jit for bf16

Signed-off-by: ericharper <complex451@gmail.com>

* replace apex.mpu

Signed-off-by: ericharper <complex451@gmail.com>

* fix grad clip

Signed-off-by: ericharper <complex451@gmail.com>

* NVFuser fixes (#2951)

* Fuser fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove dummy handler

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove PTL plugin based logic for fusion

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* remove duplicated file

Signed-off-by: ericharper <complex451@gmail.com>

* typo (#2960)

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Script to convert GPT checkpoint to .nemo (#2958)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* add script to convert .ckpt to .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update

Signed-off-by: ericharper <complex451@gmail.com>

* convert mp checkpoints to nemo

Signed-off-by: ericharper <complex451@gmail.com>

* update help

Signed-off-by: ericharper <complex451@gmail.com>

* add safeguard for model parallel save_to

Signed-off-by: ericharper <complex451@gmail.com>

* adjust NLPModel save_to to be safer for model parallel

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* [BigNLP] Update GPT evaluation to work with tensor model parallel  (#2959)

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* add request dataset

Signed-off-by: ericharper <complex451@gmail.com>

* tokenize request

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* able to run

Signed-off-by: ericharper <complex451@gmail.com>

* reduce logits

Signed-off-by: ericharper <complex451@gmail.com>
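
A generic sketch of what reducing logits across a tensor-model-parallel group can look like with plain torch.distributed (the actual NeMo/Apex helpers differ; this assumes each rank holds a contiguous vocabulary shard):

    import torch
    import torch.distributed as dist

    def gather_vocab_parallel_logits(logits_shard, tp_group=None):
        """All-gather vocab-parallel logits along the vocabulary dimension."""
        world_size = dist.get_world_size(group=tp_group)
        if world_size == 1:
            return logits_shard  # non model parallel case
        shards = [torch.empty_like(logits_shard) for _ in range(world_size)]
        dist.all_gather(shards, logits_shard, group=tp_group)
        return torch.cat(shards, dim=-1)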

* capture response

Signed-off-by: ericharper <complex451@gmail.com>

* squeeze and unsqueeze

Signed-off-by: ericharper <complex451@gmail.com>

* handle non model parallel case

Signed-off-by: ericharper <complex451@gmail.com>

* clean imports

Signed-off-by: ericharper <complex451@gmail.com>

* add file

Signed-off-by: ericharper <complex451@gmail.com>

* convert logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
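
For reference, converting logits to log-probabilities is a log-softmax over the vocabulary dimension:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 8, 50257)           # [batch, sequence, vocab]
    log_probs = F.log_softmax(logits, dim=-1)   # normalized log-probabilities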

* rename logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* add megatron gpt pretraining

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron gpt pretraining

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron gpt pretraining

Signed-off-by: ericharper <complex451@gmail.com>

* updating to work with latest megatron

Signed-off-by: ericharper <complex451@gmail.com>

* updating to work with latest megatron

Signed-off-by: ericharper <complex451@gmail.com>

* update _del_model

Signed-off-by: ericharper <complex451@gmail.com>

* adding gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* adding gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* adding gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* instantiate GPTModel

Signed-off-by: ericharper <complex451@gmail.com>

* adding build dataset

Signed-off-by: ericharper <complex451@gmail.com>

* build megatron dataset in .setup

Signed-off-by: ericharper <complex451@gmail.com>

* setup dataloader

Signed-off-by: ericharper <complex451@gmail.com>

* add vocab_file and merge_file to megatron init

Signed-off-by: ericharper <complex451@gmail.com>

* add forward

Signed-off-by: ericharper <complex451@gmail.com>

* add train loss

Signed-off-by: ericharper <complex451@gmail.com>

* add optimizer

Signed-off-by: ericharper <complex451@gmail.com>

* add exp_manager

Signed-off-by: ericharper <complex451@gmail.com>

* multi-gpu is working

Signed-off-by: ericharper <complex451@gmail.com>

* adding val loop

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* adding val loop

Signed-off-by: ericharper <complex451@gmail.com>

* fix ranks

Signed-off-by: ericharper <complex451@gmail.com>

* fix model parallel checkpoint saving

Signed-off-by: ericharper <complex451@gmail.com>

* fix _del_model

Signed-off-by: ericharper <complex451@gmail.com>

* added megatron batch sampler

Signed-off-by: ericharper <complex451@gmail.com>

* try to fix num steps

Signed-off-by: ericharper <complex451@gmail.com>

* add wandb to config

Signed-off-by: ericharper <complex451@gmail.com>

* log lr

Signed-off-by: ericharper <complex451@gmail.com>

* add warmup ratio to config

Signed-off-by: ericharper <complex451@gmail.com>

* update configs

Signed-off-by: ericharper <complex451@gmail.com>

* update configs

Signed-off-by: ericharper <complex451@gmail.com>

* add cpu init to args

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* Initial megatron dataset port

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix merge conflicts

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* License fixes and megatron model porting

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* More fixes to import from nemo rather than megatron

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix circular imports

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Revert config file

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Restructure further to avoid circular imports

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add Makefile

Signed-off-by: ericharper <complex451@gmail.com>

* Add megatron modules

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add license

Signed-off-by: ericharper <complex451@gmail.com>

* Port from latest megatron

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update cfg

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* add _del_model_without_trainer

Signed-off-by: ericharper <complex451@gmail.com>

* add data preprocessing script

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* use apex mpu

Signed-off-by: ericharper <complex451@gmail.com>

* replace print_rank_0 with nemo utils logging

Signed-off-by: ericharper <complex451@gmail.com>

* use apex mpu

Signed-off-by: ericharper <complex451@gmail.com>

* use apex mpu

Signed-off-by: ericharper <complex451@gmail.com>

* add use_cpu_initialization

Signed-off-by: ericharper <complex451@gmail.com>

* fixing autoresume in progress

Signed-off-by: ericharper <complex451@gmail.com>

* properly removing last checkpoint

Signed-off-by: ericharper <complex451@gmail.com>

* log consumed samples

Signed-off-by: ericharper <complex451@gmail.com>

* fix mp autoresume

Signed-off-by: ericharper <complex451@gmail.com>

* add NLPSaveRestoreConnector

Signed-off-by: ericharper <complex451@gmail.com>

* Megatron GPT training with NeMo tokenizers (#2818)

* Update files from megatron repo

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove non NLP data related files from megatron

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Merge megatron and nemo tokenizers

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove get_tokenizer() calls from gpt model

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Update tokenizer yaml config

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add todo

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* make init_method_std configurable

Signed-off-by: ericharper <complex451@gmail.com>

* make gpu init work by setting random seed earlier

Signed-off-by: ericharper <complex451@gmail.com>

* fix gpu init after removing debug print in mpu

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_adam

Signed-off-by: ericharper <complex451@gmail.com>

* check ds is not none before logging len

Signed-off-by: ericharper <complex451@gmail.com>

* set fp16 arg to true and fix enum conflict

Signed-off-by: ericharper <complex451@gmail.com>

* make fp16 arg configurable

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* Linear warmup with cosine annealing and constant holding (#2846)

* Testing cosine schedule

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* More fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update config for constant steps in schedule

Signed-off-by: ericharper <complex451@gmail.com>

* temporarily import enum from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip for fp32

Signed-off-by: ericharper <complex451@gmail.com>

* update check for _del_model_without_trainer

Signed-off-by: ericharper <complex451@gmail.com>

* updating restore for model parallel

Signed-off-by: ericharper <complex451@gmail.com>

* add predict script

Signed-off-by: ericharper <complex451@gmail.com>

* update test iters

Signed-off-by: ericharper <complex451@gmail.com>

* add barrier

Signed-off-by: ericharper <complex451@gmail.com>

* return if clip_val is 0 or None

Signed-off-by: ericharper <complex451@gmail.com>

* when using amp clip grads after they are unscaled

Signed-off-by: ericharper <complex451@gmail.com>

* make native amp scaler hyperparams configurable

Signed-off-by: ericharper <complex451@gmail.com>

* (1) nvfuser, (2) amp-casting decoration (#2894)

* (1) nvfuser, (2) amp-casting decoration

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* support bf16

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* update package info

Signed-off-by: ericharper <complex451@gmail.com>

* add set device to constructor

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Remove megatron-lm dependency. (#2910)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>

* NVfuser (#2943)

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* Megatron gpt bfloat support (#2926)

* Save/restore fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Another merge

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Bf16 args in init

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Set precision

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove debug stuff

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add bf16 casting decorator

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Bfloat layernorm propagation

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* More arg removal

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove BERTDataset

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update to latest apex and patch transformer autocast

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: ericharper <complex451@gmail.com>

* don't set jit for bf16

Signed-off-by: ericharper <complex451@gmail.com>

* replace apex.mpu

Signed-off-by: ericharper <complex451@gmail.com>

* fix grad clip

Signed-off-by: ericharper <complex451@gmail.com>

* NVFuser fixes (#2951)

* Fuser fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove dummy handler

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove PTL plugin based logic for fusion

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* remove duplicated file

Signed-off-by: ericharper <complex451@gmail.com>

* typo (#2960)

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Script to convert GPT checkpoint to .nemo (#2958)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* add script to convert .ckpt to .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update

Signed-off-by: ericharper <complex451@gmail.com>

* convert mp checkpoints to nemo

Signed-off-by: ericharper <complex451@gmail.com>

* update help

Signed-off-by: ericharper <complex451@gmail.com>

* add safeguard for model parallel save_to

Signed-off-by: ericharper <complex451@gmail.com>

* adjust NLPModel save_to to be safer for model parallel

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* [BigNLP] Update GPT evaluation to work with tensor model parallel  (#2959)

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* add request dataset

Signed-off-by: ericharper <complex451@gmail.com>

* tokenize request

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* able to run

Signed-off-by: ericharper <complex451@gmail.com>

* reduce logits

Signed-off-by: ericharper <complex451@gmail.com>

* capture response

Signed-off-by: ericharper <complex451@gmail.com>

* squeeze and unsqueeze

Signed-off-by: ericharper <complex451@gmail.com>

* handle non model parallel case

Signed-off-by: ericharper <complex451@gmail.com>

* clean imports

Signed-off-by: ericharper <complex451@gmail.com>

* add file

Signed-off-by: ericharper <complex451@gmail.com>

* convert logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* rename logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* fix copyright headers

Signed-off-by: ericharper <complex451@gmail.com>

* fix copyright headers

Signed-off-by: ericharper <complex451@gmail.com>

* remove old TimingCallback

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins to use latest apex and sandeep's fork

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* try 21.09 container

Signed-off-by: ericharper <complex451@gmail.com>

* try cuda container

Signed-off-by: ericharper <complex451@gmail.com>

* use internal container

Signed-off-by: ericharper <complex451@gmail.com>

* update checkpoint tests

Signed-off-by: ericharper <complex451@gmail.com>

* fix scheduler args

Signed-off-by: ericharper <complex451@gmail.com>

* update eval

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins to use ptl 1.5 rc

Signed-off-by: ericharper <complex451@gmail.com>

* add import guard to jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* add import guard to jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* remove deterministic

Signed-off-by: ericharper <complex451@gmail.com>

* install numba 0.53

Signed-off-by: ericharper <complex451@gmail.com>

* allow for more variance

Signed-off-by: ericharper <complex451@gmail.com>

* update trainer config dataclass

Signed-off-by: ericharper <complex451@gmail.com>

* test_get_optimizer on gpu

Signed-off-by: ericharper <complex451@gmail.com>

* revert comment

Signed-off-by: ericharper <complex451@gmail.com>

* change trainer config default to 32

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Remove fused kernel code instead use Apex (#2984)

* remove fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* remove fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* remove fused layer norm and fused softmax and use apex instead

Signed-off-by: ericharper <complex451@gmail.com>

* update imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove comment

Signed-off-by: ericharper <complex451@gmail.com>

* use apex enums

Signed-off-by: ericharper <complex451@gmail.com>

* use apex enums

Signed-off-by: ericharper <complex451@gmail.com>

* add tab

Signed-off-by: ericharper <complex451@gmail.com>

* Timer with sliding window (#3002)

Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
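
A minimal sketch of a sliding-window timer (illustrative only, not the class added in #3002): keep the most recent N durations and report their mean.

    import time
    from collections import deque

    class SlidingWindowTimer:
        """Average the duration of the last `window_size` timed intervals."""

        def __init__(self, window_size=100):
            self._durations = deque(maxlen=window_size)
            self._start = None

        def start(self):
            self._start = time.perf_counter()

        def stop(self):
            self._durations.append(time.perf_counter() - self._start)

        def mean(self):
            return sum(self._durations) / len(self._durations) if self._durations else 0.0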

* revert tab

Signed-off-by: ericharper <complex451@gmail.com>

* check for rank zero

Signed-off-by: ericharper <complex451@gmail.com>

* check for rank zero

Signed-off-by: ericharper <complex451@gmail.com>

* try explicit log dir

Signed-off-by: ericharper <complex451@gmail.com>

* add +

Signed-off-by: ericharper <complex451@gmail.com>

* don't rm

Signed-off-by: ericharper <complex451@gmail.com>

* make dir if it doesn't exist

Signed-off-by: ericharper <complex451@gmail.com>

* create mp nemo file in temp directory

Signed-off-by: ericharper <complex451@gmail.com>

* simplify mp save_to

Signed-off-by: ericharper <complex451@gmail.com>

* handle mp 1 case

Signed-off-by: ericharper <complex451@gmail.com>

* style fix

Signed-off-by: ericharper <complex451@gmail.com>

* remove files

Signed-off-by: ericharper <complex451@gmail.com>

* fix consumed_samples when resuming

Signed-off-by: ericharper <complex451@gmail.com>
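
Conceptually (a hedged sketch, not the exact NeMo code), the consumed-sample count restored on resume is the number of completed steps times the global batch size plus any initial offset:

    def compute_consumed_samples(steps_completed, micro_batch_size,
                                 data_parallel_size, accumulate_grad_batches=1,
                                 init_consumed_samples=0):
        """Samples seen so far = initial offset + steps * global batch size."""
        global_batch = micro_batch_size * data_parallel_size * accumulate_grad_batches
        return init_consumed_samples + steps_completed * global_batch

    # e.g. resuming after 1000 optimizer steps, micro batch 4, 8 data-parallel ranks
    print(compute_consumed_samples(1000, 4, 8))  # 32000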

* fix reinstall.sh

Signed-off-by: ericharper <complex451@gmail.com>

* update req

Signed-off-by: ericharper <complex451@gmail.com>

* add more detailed log for dataloaders

Signed-off-by: ericharper <complex451@gmail.com>

* check if cuda is available before using fused_adam

Signed-off-by: ericharper <complex451@gmail.com>
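
A sketch of the guard, assuming Apex's FusedAdam is the GPU path (the CPU fallback optimizer here is illustrative):

    import torch

    def build_optimizer(params, lr=1e-4):
        """Prefer Apex FusedAdam when CUDA is available, otherwise fall back."""
        if torch.cuda.is_available():
            try:
                from apex.optimizers import FusedAdam
                return FusedAdam(params, lr=lr)
            except ImportError:
                pass  # Apex not installed
        return torch.optim.AdamW(params, lr=lr)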

* revert comment

Signed-off-by: ericharper <complex451@gmail.com>

* update eval script to use model.freeze

Signed-off-by: ericharper <complex451@gmail.com>

* log train loss averaged over gradient accumulation steps

Signed-off-by: ericharper <complex451@gmail.com>
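
Illustratively, the logged training loss becomes the mean over the micro-batches of one gradient-accumulation window rather than the last micro-batch alone:

    import torch

    def averaged_train_loss(micro_batch_losses):
        """Mean loss across the micro-batches of one accumulation window."""
        return torch.stack([loss.detach() for loss in micro_batch_losses]).mean()

    # e.g. 4 accumulation steps
    print(averaged_train_loss([torch.tensor(2.0), torch.tensor(1.5),
                               torch.tensor(1.8), torch.tensor(1.7)]))  # tensor(1.7500)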

* check copyright earlier

Signed-off-by: ericharper <complex451@gmail.com>

* todo

Signed-off-by: ericharper <complex451@gmail.com>

* override SaveRestoreConnector in NLPModel init

Signed-off-by: ericharper <complex451@gmail.com>

* move to scripts

Signed-off-by: ericharper <complex451@gmail.com>

* remove star import

Signed-off-by: ericharper <complex451@gmail.com>

* remove comments

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused dataset

Signed-off-by: ericharper <complex451@gmail.com>

* removed barrier

Signed-off-by: ericharper <complex451@gmail.com>

* check cfg

Signed-off-by: ericharper <complex451@gmail.com>

* remove logging

Signed-off-by: ericharper <complex451@gmail.com>

* freeze, unfreeze

Signed-off-by: ericharper <complex451@gmail.com>

* return None

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* add TODO

Signed-off-by: ericharper <complex451@gmail.com>

* typecheck

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* todo

Signed-off-by: ericharper <complex451@gmail.com>

* add common native plugin

Signed-off-by: ericharper <complex451@gmail.com>

* restore with trainer

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* deprecate megatron-lm bert

Signed-off-by: ericharper <complex451@gmail.com>

* deprecate megatron-lm bert

Signed-off-by: ericharper <complex451@gmail.com>

* compile helpers on the fly

Signed-off-by: ericharper <complex451@gmail.com>

* remove amp_level

Signed-off-by: ericharper <complex451@gmail.com>

* remove amp_level from configs

Signed-off-by: ericharper <complex451@gmail.com>

* add missing import

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* remove amp_level

Signed-off-by: ericharper <complex451@gmail.com>

* use fast huggingface tokenizers by default

Signed-off-by: ericharper <complex451@gmail.com>

* deal with huggingface tokenizer positional args

Signed-off-by: ericharper <complex451@gmail.com>

* deal with huggingface tokenizer positional args

Signed-off-by: ericharper <complex451@gmail.com>

* deal with huggingface tokenizer positional args

Signed-off-by: ericharper <complex451@gmail.com>

* revert use_fast default to False

Signed-off-by: ericharper <complex451@gmail.com>
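
For reference, the Hugging Face flag involved (the model name here is only an example; as noted above, the NeMo default is reverted to False):

    from transformers import AutoTokenizer

    # use_fast=True loads the Rust-backed "fast" tokenizer when one exists;
    # use_fast=False forces the pure-Python implementation.
    slow_tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
    fast_tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)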

* return super training_epoch_end

Signed-off-by: ericharper <complex451@gmail.com>

* remove optimizer_idx arg from training_step

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused arg from on_train_epoch_end

Signed-off-by: ericharper <complex451@gmail.com>

* add restore_from_path to nemo config

Signed-off-by: ericharper <complex451@gmail.com>

* add comment

Signed-off-by: ericharper <complex451@gmail.com>

* revert

Signed-off-by: ericharper <complex451@gmail.com>

* override connector if not subclassing NLPSaveRestoreConnector for model parallel save

Signed-off-by: ericharper <complex451@gmail.com>

* update test optimizer

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* make data_prefix mandatory in config

Signed-off-by: ericharper <complex451@gmail.com>

* update installation instructions on readme

Signed-off-by: ericharper <complex451@gmail.com>

* update dockerfile

Signed-off-by: ericharper <complex451@gmail.com>

* add todo

Signed-off-by: ericharper <complex451@gmail.com>

* raise error if trying to use always_save_nemo with model parallel model

Signed-off-by: ericharper <complex451@gmail.com>

* remove comment

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
commit 32fa5cfaf3 (parent 620e8a8986)
Eric Harper, 2021-10-20 21:06:37 -06:00, committed via GitHub
106 changed files with 11014 additions and 865 deletions

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:21.08-py3
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:21.10-py3
# build an image that includes only the nemo dependencies, ensures that dependencies

Jenkinsfile (403 lines changed)

@@ -1,7 +1,7 @@
pipeline {
agent {
docker {
image 'nvcr.io/nvidia/pytorch:21.08-py3'
image 'gitlab-master.nvidia.com/dl/dgx/pytorch:21.10-py3-devel'
args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache/torch:/root/.cache/torch --shm-size=8g'
}
}
@@ -10,6 +10,7 @@ pipeline {
disableConcurrentBuilds()
}
stages {
stage('PyTorch version') {
steps {
sh 'python -c "import torch; print(torch.__version__)"'
@@ -23,6 +24,12 @@ pipeline {
}
}
stage('Code formatting checks') {
steps {
sh 'python setup.py style'
}
}
stage('Copyright Headers check') {
steps {
sh 'python tests/check_copyright_header.py --dir .'
@@ -35,12 +42,6 @@
}
}
stage('Code formatting checks') {
steps {
sh 'python setup.py style'
}
}
stage('Torch TTS unit tests') {
when {
anyOf {
@@ -58,7 +59,14 @@
}
}
stage('Installation') {
// Revert once import guards are added by PTL or version comparing is fixed
stage('Install PyTorch Lighting 1.5 RC') {
steps{
sh 'pip install pytorch-lightning==1.5.0rc0 && sed -i "s/from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining/try:\\n\\tfrom pytorch_lightning.callbacks.quantization import QuantizationAwareTraining\\nexcept:\\n\\tpass/g" /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py'
}
}
stage('NeMo Installation') {
steps {
sh './reinstall.sh release'
}
@@ -605,30 +613,31 @@
}
}
stage('L2: Multi-GPU Megatron finetuning') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
parallel {
stage('L2: Cased Megatron finetuning on MRPC') {
steps {
sh 'cd examples/nlp/glue_benchmark && \
python glue_benchmark.py \
model.dataset.data_dir=/home/TestData/nlp/glue_fake/MRPC \
trainer.gpus=[0,1] \
+trainer.fast_dev_run=true \
model.dataset.use_cache=false \
model.language_model.pretrained_model_name=megatron-bert-345m-cased \
trainer.accelerator=ddp \
exp_manager=null'
}
}
}
}
// TODO: add test once megatron-bert is supported again
// stage('L2: Multi-GPU Megatron finetuning') {
// when {
// anyOf {
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// parallel {
// stage('L2: Cased Megatron finetuning on MRPC') {
// steps {
// sh 'cd examples/nlp/glue_benchmark && \
// python glue_benchmark.py \
// model.dataset.data_dir=/home/TestData/nlp/glue_fake/MRPC \
// trainer.gpus=[0,1] \
// +trainer.fast_dev_run=true \
// model.dataset.use_cache=false \
// model.language_model.pretrained_model_name=megatron-bert-345m-cased \
// trainer.accelerator=ddp \
// exp_manager=null'
// }
// }
// }
// }
stage('L2: SGD-QA') {
when {
@@ -725,7 +734,6 @@
model.language_model.pretrained_model_name=bert-base-uncased \
model.dataset.version_2_with_negative=false \
trainer.precision=16 \
trainer.amp_level=O1 \
trainer.gpus=[0] \
exp_manager=null'
}
@@ -747,7 +755,6 @@
model.language_model.pretrained_model_name=bert-base-uncased \
model.dataset.version_2_with_negative=true \
trainer.precision=16 \
trainer.amp_level=O1 \
trainer.gpus=[1] \
exp_manager=null'
}
@@ -777,30 +784,31 @@
}
}
// Runs out of memory on the 12G TITAN V (GPU 0 on main CI)
stage('L2: MegaBERT Token Classification') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh 'cd examples/nlp/token_classification && \
python token_classification_train.py \
model.dataset.data_dir=/home/TestData/nlp/token_classification_punctuation/ \
model.language_model.pretrained_model_name=megatron-bert-345m-uncased \
model.train_ds.batch_size=10 \
model.dataset.max_seq_length=50 \
model.dataset.use_cache=false \
trainer.accelerator=ddp \
trainer.precision=16 \
trainer.amp_level=O1 \
trainer.gpus=[1] \
+trainer.fast_dev_run=true \
exp_manager=null'
}
}
// TODO: add when megatron bert is supported again in NeMo
// stage('L2: MegaBERT Token Classification') {
// when {
// anyOf {
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// steps {
// sh 'cd examples/nlp/token_classification && \
// python token_classification_train.py \
// model.dataset.data_dir=/home/TestData/nlp/token_classification_punctuation/ \
// model.language_model.pretrained_model_name=megatron-bert-345m-uncased \
// model.train_ds.batch_size=10 \
// model.dataset.max_seq_length=50 \
// model.dataset.use_cache=false \
// trainer.accelerator=ddp \
// trainer.precision=16 \
// trainer.gpus=[1] \
// +trainer.fast_dev_run=true \
// exp_manager=null'
// }
// }
stage('L2: Parallel SQUAD v1.1 & v2.0') {
when {
anyOf {
@@ -810,8 +818,13 @@
}
failFast true
parallel {
stage('SQUAD v2.0 with Megatron with ckpt & config') {
// TODO: use megatron bert when supported again
stage('SQUAD v2.0 with DistilBERT Uncased') {
// stage('SQUAD v2.0 with Megatron with ckpt & config') {
// Cannot do fast_dev_run because squad needs whole dev dataset
// model.language_model.pretrained_model_name=megatron-bert-uncased \
// model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \
// model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \
steps {
sh 'cd examples/nlp/question_answering && \
python question_answering_squad.py \
@@ -825,12 +838,9 @@
trainer.max_epochs=1 \
+trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
model.language_model.pretrained_model_name=megatron-bert-uncased \
model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \
model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \
model.language_model.pretrained_model_name=distilbert-base-uncased \
model.dataset.version_2_with_negative=true \
trainer.precision=16 \
trainer.amp_level=O1 \
trainer.gpus=[1] \
exp_manager=null'
}
@@ -852,7 +862,6 @@
model.language_model.pretrained_model_name=roberta-base \
model.dataset.version_2_with_negative=false \
trainer.precision=16 \
trainer.amp_level=O1 \
trainer.gpus=[0] \
exp_manager=null'
}
@@ -864,7 +873,7 @@
model.dataset.num_classes=6 \
model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
model.validation_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
model.language_model.pretrained_model_name=bert-base-uncased \
model.language_model.pretrained_model_name=distilbert-base-uncased \
model.train_ds.batch_size=10 \
model.dataset.max_seq_length=50 \
model.dataset.use_cache=false \
@@ -889,107 +898,106 @@
}
}
stage('L2: Model Parallel Size 2 Megatron Text Classification') {
when {
anyOf{
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps{
sh 'cd examples/nlp/text_classification && \
python text_classification_with_bert.py \
trainer.gpus=[0,1] \
trainer.num_nodes=1 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
~trainer.amp_level \
+trainer.fast_dev_run=true \
model.dataset.num_classes=6 \
model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
model.train_ds.batch_size=4 \
model.language_model.pretrained_model_name=megatron-bert-uncased \
model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
model.nemo_path=null \
~model.infer_samples \
exp_manager=null'
}
}
// TODO: add when megatron-bert is supported again
// stage('L2: Model Parallel Size 2 Megatron Text Classification') {
// when {
// anyOf{
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// steps{
// sh 'cd examples/nlp/text_classification && \
// python text_classification_with_bert.py \
// trainer.gpus=[0,1] \
// trainer.num_nodes=1 \
// trainer.precision=16 \
// trainer.gradient_clip_val=1.0 \
// +trainer.fast_dev_run=true \
// model.dataset.num_classes=6 \
// model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
// model.train_ds.batch_size=4 \
// model.language_model.pretrained_model_name=megatron-bert-uncased \
// model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
// model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
// model.nemo_path=null \
// ~model.infer_samples \
// exp_manager=null'
// }
// }
stage('L2: Model Parallel Size 2 Megatron Autoresume') {
when {
anyOf{
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps{
sh 'cd examples/nlp/text_classification && \
python text_classification_with_bert.py \
trainer.gpus=[0,1] \
trainer.num_nodes=1 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
trainer.max_epochs=1 \
~trainer.amp_level \
+trainer.fast_dev_run=true \
model.dataset.num_classes=6 \
model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
model.train_ds.batch_size=4 \
model.language_model.pretrained_model_name=megatron-bert-uncased \
model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
model.nemo_path=null \
~model.infer_samples \
+exp_manager.explicit_log_dir=/home/TestData/nlp/mp_autoresume \
+exp_manager.resume_if_exists=true'
}
}
// stage('L2: Model Parallel Size 2 Megatron Autoresume') {
// when {
// anyOf{
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// steps{
// sh 'cd examples/nlp/text_classification && \
// python text_classification_with_bert.py \
// trainer.gpus=[0,1] \
// trainer.num_nodes=1 \
// trainer.precision=16 \
// trainer.gradient_clip_val=1.0 \
// trainer.max_epochs=1 \
// +trainer.fast_dev_run=true \
// model.dataset.num_classes=6 \
// model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
// model.train_ds.batch_size=4 \
// model.language_model.pretrained_model_name=megatron-bert-uncased \
// model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
// model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
// model.nemo_path=null \
// ~model.infer_samples \
// +exp_manager.explicit_log_dir=/home/TestData/nlp/mp_autoresume \
// +exp_manager.resume_if_exists=true'
// }
// }
stage('L2: Model Parallel Size 2 Megatron Evaluation from .nemo') {
when {
anyOf{
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps{
sh 'cd examples/nlp/text_classification && \
python model_parallel_text_classification_evaluation.py \
trainer.gpus=[0,1] \
trainer.num_nodes=1 \
model.dataset.num_classes=6 \
model.test_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
model.nemo_path=/home/TestData/nlp/mp_2_nemo/retail_text_class_350M.nemo \
exp_manager=null'
}
}
// stage('L2: Model Parallel Size 2 Megatron Evaluation from .nemo') {
// when {
// anyOf{
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// steps{
// sh 'cd examples/nlp/text_classification && \
// python model_parallel_text_classification_evaluation.py \
// trainer.gpus=[0,1] \
// trainer.num_nodes=1 \
// model.dataset.num_classes=6 \
// model.test_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
// model.nemo_path=/home/TestData/nlp/mp_2_nemo/retail_text_class_350M.nemo \
// exp_manager=null'
// }
// }
stage('L2: Model Parallel Size 2 Megatron Train from .nemo') {
when {
anyOf{
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps{
sh 'cd examples/nlp/token_classification && \
python token_classification_train.py \
pretrained_model=/home/TestData/nlp/mp_2_nemo/ner_350M.nemo \
model.dataset.data_dir=/home/TestData/nlp/ner/ \
model.train_ds.batch_size=2 \
model.dataset.use_cache=false \
trainer.gpus=[0,1] \
+trainer.fast_dev_run=true \
model.dataset.class_balancing="weighted_loss" \
exp_manager=null'
}
}
// stage('L2: Model Parallel Size 2 Megatron Train from .nemo') {
// when {
// anyOf{
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// steps{
// sh 'cd examples/nlp/token_classification && \
// python token_classification_train.py \
// pretrained_model=/home/TestData/nlp/mp_2_nemo/ner_350M.nemo \
// model.dataset.data_dir=/home/TestData/nlp/ner/ \
// model.train_ds.batch_size=2 \
// model.dataset.use_cache=false \
// trainer.gpus=[0,1] \
// +trainer.fast_dev_run=true \
// model.dataset.class_balancing="weighted_loss" \
// exp_manager=null'
// }
// }
stage('L2: Parallel NLP Examples 2') {
when {
@@ -1089,7 +1097,6 @@
--config-name=bert_pretraining_from_text_config.yaml \
trainer.gpus=[0] \
trainer.precision=16 \
trainer.amp_level=O1 \
+trainer.fast_dev_run=true \
model.train_ds.data_file=/home/TestData/nlp/wikitext-2/train.txt \
model.train_ds.batch_size=32 \
@@ -1116,7 +1123,6 @@
--config-name=bert_pretraining_from_preprocessed_config.yaml \
trainer.gpus=[1] \
trainer.precision=16 \
trainer.amp_level=O1 \
+trainer.fast_dev_run=true \
model.train_ds.data_file=/home/TestData/nlp/wiki_book_mini/training \
model.train_ds.batch_size=8 \
@ -1351,39 +1357,40 @@ pipeline {
}
}
stage('L2: NMT Megatron BERT Model Parallel Size 2 Encoder') {
when {
anyOf{
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps{
sh 'cd examples/nlp/machine_translation && \
python enc_dec_nmt.py \
--config-path=conf \
--config-name=megatron \
model.encoder.model_name=megatron-bert-uncased \
model.encoder.checkpoint_file=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
model.encoder.hidden_size=1024 \
model.encoder.num_attention_heads=16 \
model.encoder.num_layers=24 \
model.encoder.max_position_embeddings=512 \
model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.decoder_tokenizer.tokenizer_model=/home/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
model.decoder.hidden_size=1024 \
trainer.gpus=[0,1] \
+trainer.fast_dev_run=true \
exp_manager=null \
'
}
}
// TODO: add when megatron bert is supported again in NeMo
// stage('L2: NMT Megatron BERT Model Parallel Size 2 Encoder') {
// when {
// anyOf{
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// steps{
// sh 'cd examples/nlp/machine_translation && \
// python enc_dec_nmt.py \
// --config-path=conf \
// --config-name=megatron \
// model.encoder.model_name=megatron-bert-uncased \
// model.encoder.checkpoint_file=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
// model.encoder.hidden_size=1024 \
// model.encoder.num_attention_heads=16 \
// model.encoder.num_layers=24 \
// model.encoder.max_position_embeddings=512 \
// model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
// model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
// model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
// model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
// model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
// model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
// model.decoder_tokenizer.tokenizer_model=/home/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
// model.decoder.hidden_size=1024 \
// trainer.gpus=[0,1] \
// +trainer.fast_dev_run=true \
// exp_manager=null \
// '
// }
// }
stage('L2: NMT Tarred Dataset Creation') {
when {


@ -85,7 +85,7 @@ Requirements
------------
1) Python 3.6, 3.7 or 3.8
2) Pytorch 1.8.1 or above
2) Pytorch 1.10.0 or above
3) NVIDIA GPU for training
Documentation
@ -163,16 +163,28 @@ Note that RNNT requires numba to be installed from conda.
pip uninstall numba
conda install -c numba numba
Megatron GPT
~~~~~~~~~~~~
Megatron GPT training requires NVIDIA Apex to be installed.
.. code-block:: bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
Docker containers:
~~~~~~~~~~~~~~~~~~
If you chose to work with main branch, we recommend using NVIDIA's PyTorch container version 21.08-py3 and then installing from GitHub.
If you chose to work with main branch, we recommend using NVIDIA's PyTorch container version 21.10-py3 and then installing from GitHub.
Note that NVIDIA's PyTorch 21.10-py3 container has not yet been released publicly. Please use a container with the nightly version of PyTorch installed if you are
unable to access NVIDIA's PyTorch 21.10 container.
.. code-block:: bash
docker run --gpus all -it --rm -v <nemo_github_folder>:/NeMo --shm-size=8g \
-p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit \
stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:21.08-py3
stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:21.10-py3
Examples
--------


@ -156,7 +156,6 @@ trainer:
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
progress_bar_refresh_rate: 10


@ -131,7 +131,6 @@ trainer:
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
progress_bar_refresh_rate: 10


@ -206,7 +206,6 @@ trainer:
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
progress_bar_refresh_rate: 10


@ -201,7 +201,6 @@ trainer:
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
progress_bar_refresh_rate: 10


@ -100,7 +100,6 @@ trainer:
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.


@ -100,7 +100,6 @@ trainer:
accelerator: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 10 # Interval of logging.
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.


@ -22,7 +22,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 1.0
amp_level: O1 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP.
accelerator: ddp
log_every_n_steps: 5 # Interval of logging.


@ -15,7 +15,6 @@ tagger_trainer:
logger: false # provided by exp_manager
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
accelerator: ddp
@ -66,7 +65,6 @@ decoder_trainer:
logger: false # provided by exp_manager
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
accelerator: ddp
log_every_n_steps: 1 # Interval of logging.


@ -7,7 +7,6 @@ trainer:
max_steps: null
accumulate_grad_batches: 1
precision: 16
amp_level: O1
accelerator: ddp
gradient_clip_val: 0.0
log_every_n_steps: 1


@ -7,7 +7,6 @@ trainer:
max_steps: null
accumulate_grad_batches: 1
precision: 16
amp_level: O1
accelerator: ddp
gradient_clip_val: 0.0
log_every_n_steps: 1


@ -7,7 +7,6 @@ trainer:
max_epochs: 3
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
amp_level: O0 # O1/O2 for mixed precision
precision: 16
accelerator: ddp
checkpoint_callback: False # Provided by exp_manager


@ -7,7 +7,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
precision: 16 # 16 to use AMP
amp_level: O1 # O1 or O2 if using AMP
accelerator: ddp
log_every_n_steps: 1 # Interval of logging.
val_check_interval: 0.05 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations


@ -6,7 +6,6 @@ trainer:
max_epochs: 50
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 amp_level to enable the AMP.
accelerator: ddp
log_every_n_steps: 1 # Interval of logging.


@ -8,7 +8,6 @@ trainer:
replace_sampler_ddp: false # needed for bert pretraining from preproc
accumulate_grad_batches: 1 # accumulates grads every k batches
precision: 16 # 16 to use AMP
amp_level: O1 # O1 or O2 if using AMP
accelerator: ddp
gradient_clip_val: 1.0
log_every_n_steps: 1


@ -7,7 +7,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
precision: 16 # 16 to use AMP
amp_level: O1 # O1 or O2 if using AMP
accelerator: ddp
gradient_clip_val: 0.0
log_every_n_steps: 1


@ -0,0 +1,117 @@
name: megatron_gpt
restore_from_path: null # used when starting from a .nemo file
trainer:
gpus: 1
num_nodes: 1
accelerator: ddp
precision: 16
logger: False # logger provided by exp_manager
checkpoint_callback: False
replace_sampler_ddp: False
max_epochs: null
max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 50
limit_test_batches: 500
accumulate_grad_batches: 1
gradient_clip_val: 1.0
exp_manager:
explicit_log_dir: null
exp_dir: null
name: megatron_gpt
create_wandb_logger: True
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}'
model:
# model parallelism
micro_batch_size: 4
tensor_model_parallel_size: 1
# model architecture
encoder_seq_length: 512
max_position_embeddings: 512
num_layers: 12
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 12
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
model: null
vocab_file: null
merge_file: null
# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
fused_fp16: True # False if using fp32 or bf16
fused_bf16: False # True if using bf16
fp32_residual_connection: False # Move residual connections to fp32
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
# not implemented in NeMo yet
activations_checkpoint_method: null # 'uniform', 'block'
activations_checkpoint_num_layers: 1
data:
# Path to data must be specified by the user.
# can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]",
# Or see example below:
# data_prefix:
# - .5
# - /raid/data/pile/my-gpt3_00_text_document
# - .5
# - /raid/data/pile/my-gpt3_01_text_document
data_prefix: ???
data_impl: mmap
splits_string: 900,50,50
seq_length: 1024
skip_warmup: True
num_workers: 0
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
optim:
name: adam
lr: 2e-4
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 2e-5
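The trainer.max_steps comment above ties the global step to the number of training samples consumed. A minimal sketch of that arithmetic, with illustrative numbers (the data-parallel size of 8 is an assumption for the example, not part of this config):

.. code-block:: python

    # Sketch of the consumed_samples bookkeeping described in the config comment.
    micro_batch_size = 4            # model.micro_batch_size
    data_parallel_size = 8          # world_size // tensor_model_parallel_size (assumed)
    accumulate_grad_batches = 1     # trainer.accumulate_grad_batches
    global_step = 100000            # trainer.max_steps

    consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
    print(consumed_samples)  # 3200000 samples consumed by the end of training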


@ -88,7 +88,6 @@ trainer:
gpus: 4
num_nodes: 1
max_epochs: 200
amp_level: O2 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: False


@ -0,0 +1,80 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentParser
import torch.multiprocessing as mp
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.utils import AppState, logging
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint_folder",
type=str,
default=None,
required=True,
help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
)
parser.add_argument(
"--checkpoint_name",
type=str,
default=None,
required=True,
help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
)
parser.add_argument(
"--hparams_file",
type=str,
default=None,
required=False,
help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
)
parser.add_argument("--nemo_file_path", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None)
args = parser.parse_args()
return args
def convert(rank, world_size, args):
app_state = AppState()
app_state.data_parallel_rank = 0
trainer = Trainer(gpus=args.tensor_model_parallel_size)
# TODO: reach out to PTL for an API-safe local rank override
trainer.accelerator.training_type_plugin._local_rank = rank
checkpoint_path = os.path.join(args.checkpoint_folder, f'mp_rank_{rank:02d}', args.checkpoint_name)
model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
model._save_restore_connector = NLPSaveRestoreConnector()
model.save_to(args.nemo_file_path)
logging.info(f'NeMo model saved to: {args.nemo_file_path}')
def main() -> None:
args = get_args()
world_size = args.tensor_model_parallel_size
mp.spawn(convert, args=(world_size, args), nprocs=world_size, join=True)
if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter


@ -0,0 +1,90 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
import torch
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader
from nemo.collections.nlp.data.language_modeling.megatron.gpt_request_dataset import GPTRequestDataset
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin
from nemo.utils import logging
from nemo.utils.app_state import AppState
assert torch.cuda.is_available()
def main():
parser = ArgumentParser()
parser.add_argument("--model_file", type=str, default="", required=True, help="Pass path to model's .nemo file")
parser.add_argument(
"--prompt", type=str, default="", required=True, help="Prompt for the model (a text to complete)"
)
parser.add_argument(
"--tokens_to_generate", type=int, default="64", required=False, help="How many tokens to add to prompt"
)
parser.add_argument(
"--stop_after_sentence",
type=bool,
default="True",
required=False,
help="True/False: whether to stop after full sentence has been generated.",
)
parser.add_argument(
"--tensor_model_parallel_size", type=int, default=1, required=True,
)
parser.add_argument("--precision", default=32, help="PyTorch Lightning Trainer precision flag")
args = parser.parse_args()
# cast precision to int if 32 or 16
if args.precision in ["32", "16"]:
args.precision = int(float(args.precision))
# trainer required for restoring model parallel models
trainer = Trainer(plugins=NLPDDPPlugin(), gpus=args.tensor_model_parallel_size, precision=args.precision)
app_state = AppState()
if args.tensor_model_parallel_size is not None and args.tensor_model_parallel_size > 1:
app_state.model_parallel_size = args.tensor_model_parallel_size
app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)
model = MegatronGPTModel.restore_from(restore_path=args.model_file, trainer=trainer)
model.freeze()
request = {
"prompt": args.prompt,
"tokens_to_generate": args.tokens_to_generate,
"stop_after_sentence": args.stop_after_sentence,
}
dataset = GPTRequestDataset(request, model.tokenizer)
request_dl = DataLoader(dataset)
response = trainer.predict(model, request_dl)
print("***************************")
print(response[0]['completion']['text'])
print("***************************")
logging.info(f"Generation stopped because: {response[0]['completion']['stop reason']}")
if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter


@ -0,0 +1,79 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import (
NLPCheckpointConnector,
NLPDDPPlugin,
NLPNativeBfloat16PrecisionPlugin,
NLPNativeMixedPrecisionPlugin,
NLPPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
if cfg.trainer.precision == 16:
trainer = Trainer(
plugins=[
NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes),
NLPNativeMixedPrecisionPlugin(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
),
],
**cfg.trainer,
)
elif cfg.trainer.precision == 'bf16':
trainer = Trainer(
plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPNativeBfloat16PrecisionPlugin(),],
**cfg.trainer,
)
else:
trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPPrecisionPlugin()], **cfg.trainer)
exp_manager(trainer, cfg.exp_manager)
# update resume from checkpoint found by exp_manager
resume_from_checkpoint = trainer.resume_from_checkpoint
if resume_from_checkpoint is not None:
mp_rank = compute_model_parallel_rank(trainer.local_rank, cfg.model.tensor_model_parallel_size)
resume_from_checkpoint = Path(resume_from_checkpoint)
resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(f'mp_rank_{mp_rank:02d}').joinpath(
resume_from_checkpoint.name
)
resume_from_checkpoint = str(resume_from_checkpoint)
logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
trainer.checkpoint_connector = NLPCheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
model = MegatronGPTModel(cfg.model, trainer)
trainer.fit(model)
if __name__ == '__main__':
main()


@ -0,0 +1,72 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import (
NLPDDPPlugin,
NLPNativeBfloat16PrecisionPlugin,
NLPNativeMixedPrecisionPlugin,
NLPPrecisionPlugin,
NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState
@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
trainer = None
if cfg.trainer.precision == 16:
trainer = Trainer(
plugins=[
NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes),
NLPNativeMixedPrecisionPlugin(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
),
],
**cfg.trainer,
)
elif cfg.trainer.precision == 'bf16':
trainer = Trainer(
plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPNativeBfloat16PrecisionPlugin(),],
**cfg.trainer,
)
else:
trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPPrecisionPlugin()], **cfg.trainer)
app_state = AppState()
app_state.model_parallel_size = cfg.model.tensor_model_parallel_size
app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)
model = MegatronGPTModel.restore_from(
cfg.restore_from_path, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(),
)
# Note: most nemo models must have the data paths configured before instantiating the model
# MegatronGPTModel sets up the data in the PTL .setup() method, which runs after DDP spawns.
model.cfg.data.splits_string = cfg.model.data.splits_string
trainer.test(model)
if __name__ == '__main__':
main()


@ -145,7 +145,6 @@ trainer:
gpus: 4
num_nodes: 1
max_epochs: 200
amp_level: O2 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: False


@ -172,7 +172,6 @@ trainer:
gpus: 4
num_nodes: 1
max_epochs: 200
amp_level: O2 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: False


@ -119,7 +119,6 @@ trainer:
gpus: 4
num_nodes: 1
max_epochs: 200
amp_level: O2 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: False


@ -147,7 +147,6 @@ trainer:
gpus: 4
num_nodes: 1
max_epochs: 200
amp_level: O2 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: False


@ -10,7 +10,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
precision: 16 # 16 to use AMP
amp_level: O1 # O1 or O2 if using AMP
accelerator: ddp
gradient_clip_val: 0.0
val_check_interval: 1.0 # check once per epoch .25 for 4 times per epoch


@ -21,7 +21,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
accelerator: ddp
log_every_n_steps: 1 # Interval of logging.


@ -24,7 +24,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: false # Provided by exp_manager


@ -23,7 +23,6 @@ trainer:
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 0.0
amp_level: O0 # O1/O2 for mixed precision
precision: 16 # Should be set to 16 for O1 and O2, default is 16 as PT ignores it when am_level is O0
accelerator: ddp
checkpoint_callback: False # Provided by exp_manager


@ -62,13 +62,7 @@ def main(cfg: DictConfig) -> None:
else:
gpu = 1 if cfg.trainer.gpus != 0 else 0
trainer = pl.Trainer(
gpus=gpu,
precision=cfg.trainer.precision,
amp_level=cfg.trainer.amp_level,
logger=False,
checkpoint_callback=False,
)
trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision, logger=False, checkpoint_callback=False,)
exp_dir = exp_manager(trainer, cfg.exp_manager)
if not cfg.pretrained_model:


@ -70,13 +70,7 @@ def main(cfg: DictConfig) -> None:
else:
gpu = 1 if cfg.trainer.gpus != 0 else 0
trainer = pl.Trainer(
gpus=gpu,
precision=cfg.trainer.precision,
amp_level=cfg.trainer.amp_level,
logger=False,
checkpoint_callback=False,
)
trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision, logger=False, checkpoint_callback=False,)
exp_dir = exp_manager(trainer, cfg.exp_manager)
if not cfg.pretrained_model:


@ -19,7 +19,6 @@ trainer:
max_epochs: 1
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
amp_level: O0 # O1/O2 for mixed precision
precision: 16
accelerator: ddp
log_every_n_steps: 1 # Interval of logging.


@ -143,7 +143,6 @@ trainer:
num_nodes: 1
accelerator: ddp
accumulate_grad_batches: 1
amp_level: O0
deterministic: True
checkpoint_callback: False
logger: False


@ -131,7 +131,6 @@ trainer:
num_nodes: 1
accelerator: ddp
accumulate_grad_batches: 1
amp_level: O0
deterministic: True
checkpoint_callback: False
logger: False


@ -93,7 +93,6 @@ trainer:
num_nodes: 1
accelerator: ddp
accumulate_grad_batches: 1
amp_level: O0
deterministic: False
checkpoint_callback: False
logger: False


@ -95,7 +95,6 @@ trainer:
log_every_n_steps: 200
check_val_every_n_epoch: 25
precision: 16
amp_level: O1
exp_manager:
exp_dir: null


@ -29,7 +29,7 @@ class LogEpochTimeCallback(Callback):
self.epoch_start = time.time()
@rank_zero_only
def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module):
curr_time = time.time()
duration = curr_time - self.epoch_start
trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)


@ -0,0 +1,23 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
class NeMoNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None:
super().__init__(precision=16)
self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval)
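As a hedged usage sketch, a precision plugin like the one above is handed to a PyTorch Lightning Trainer the same way the fp16 branch of the Megatron GPT scripts in this PR builds its trainer; the GPU and node counts below are illustrative only:

.. code-block:: python

    from pytorch_lightning import Trainer

    from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin, NLPNativeMixedPrecisionPlugin

    # Mirrors the precision == 16 branch of the training scripts in this PR;
    # the grad-scaler arguments correspond to native_amp_init_scale and
    # native_amp_growth_interval in the config.
    trainer = Trainer(
        gpus=1,
        num_nodes=1,
        precision=16,
        plugins=[
            NLPDDPPlugin(num_nodes=1),
            NLPNativeMixedPrecisionPlugin(init_scale=2 ** 32, growth_interval=1000),
        ],
    )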


@ -33,6 +33,7 @@ class AutoTokenizer(TokenizerSpec):
self,
pretrained_model_name: str,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
mask_token: Optional[str] = None,
bos_token: Optional[str] = None,
eos_token: Optional[str] = None,
@ -60,19 +61,26 @@ class AutoTokenizer(TokenizerSpec):
use_fast: whether to use fast HuggingFace tokenizer
"""
try:
if vocab_file is not None:
message = 'Using "slow" HuggingFace tokenizer'
if use_fast:
message += f'{vocab_file} is ignored in "fast" tokenizers, using a "slow" version'
# this logic deals with different huggingface tokenizers having different positional args
if vocab_file is None:
self.tokenizer = AUTOTOKENIZER.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=False
pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast,
)
elif merges_file is None:
self.tokenizer = AUTOTOKENIZER.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=use_fast,
)
else:
self.tokenizer = AUTOTOKENIZER.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast
pretrained_model_name_or_path=pretrained_model_name,
vocab_file=vocab_file,
merges_file=merges_file,
use_fast=use_fast,
)
except Exception as e:
raise ValueError(f'{pretrained_model_name} is not supported by HuggingFace. {e}')
raise ValueError(
f'Unable to instantiate HuggingFace AUTOTOKENIZER for {pretrained_model_name}. Exception: {e}'
)
special_tokens_dict = {}
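A hedged usage sketch of the construction paths the branching above distinguishes; the import path is assumed from NeMo's common tokenizer package layout and the file locations are placeholders:

.. code-block:: python

    # Import path assumed from nemo.collections.common.tokenizers; adjust if it differs.
    from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

    # 1) Name only: vocab and merges files come from the HuggingFace hub.
    tokenizer = AutoTokenizer(pretrained_model_name='gpt2')

    # 2) Local vocab + merges files, the new path used for BPE tokenizers
    #    such as the Megatron GPT tokenizer (placeholder paths).
    tokenizer = AutoTokenizer(
        pretrained_model_name='gpt2',
        vocab_file='/path/to/vocab.json',
        merges_file='/path/to/merges.txt',
    )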


@ -17,4 +17,5 @@ from nemo.collections.nlp.data.language_modeling.lm_bert_dataset import (
BertPretrainingDataset,
BertPretrainingPreprocessedDataloader,
)
from nemo.collections.nlp.data.language_modeling.megatron import GPTDataset, IndexedDataset, MMapIndexedDataset
from nemo.collections.nlp.data.language_modeling.sentence_dataset import SentenceDataset, TarredSentenceDataset


@ -0,0 +1,23 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@


@ -0,0 +1,20 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import IndexedDataset, MMapIndexedDataset
# TODO: refactor these datasets to work without megatron-lm dependency
# from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset
# from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset


@ -0,0 +1,216 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Style dataset."""
import numpy as np
import torch
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
create_masked_lm_predictions,
create_tokens_and_tokentypes,
get_a_and_b_segments,
get_samples_mapping,
truncate_segments,
)
class BertDataset(torch.utils.data.Dataset):
def __init__(
self,
name,
indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
masked_lm_prob,
max_seq_length,
short_seq_prob,
seed,
binary_head,
tokenizer,
):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(
self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens
short_seq_prob,
self.seed,
self.name,
self.binary_head,
)
# Vocab stuff.
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls_id
self.sep_id = tokenizer.sep_id
self.mask_id = tokenizer.mask_id
self.pad_id = tokenizer.pad_id
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
# We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2 ** 32))
return build_training_sample(
sample,
seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng,
self.binary_head,
)
def build_training_sample(
sample,
target_seq_length,
max_seq_length,
vocab_id_list,
vocab_id_to_token_dict,
cls_id,
sep_id,
mask_id,
pad_id,
masked_lm_prob,
np_rng,
binary_head,
):
"""Biuld training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
"""
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng)
# Build tokens and toketypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated),
}
return train_sample
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
# Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
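A small worked example of the padding helper above, assuming NeMo (and Apex, which the megatron data utilities import) is installed; the token ids are toy values:

.. code-block:: python

    # Module path taken from the megatron data package added in this PR.
    from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import pad_and_convert_to_numpy

    tokens = [101, 5, 6, 7, 102]   # toy ids: [CLS] w1 w2 w3 [SEP]
    tokentypes = [0, 0, 0, 0, 0]
    masked_positions = [2]         # position that was masked
    masked_labels = [6]            # original id at that position

    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
        tokens, tokentypes, masked_positions, masked_labels, pad_id=0, max_seq_length=8
    )
    print(tokens_np)        # [101   5   6   7 102   0   0   0]
    print(padding_mask_np)  # [1 1 1 1 1 0 0 0]
    print(labels_np)        # [-1 -1  6 -1 -1 -1 -1 -1]
    print(loss_mask_np)     # [0 0 1 0 0 0 0 0]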


@ -0,0 +1,75 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blendable dataset."""
import time
import numpy as np
import torch
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = 0
for dataset in self.datasets:
self.size += len(dataset)
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
# Build indices.
start_time = time.time()
assert num_datasets < 255
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
try:
if is_global_rank_zero():
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper
compile_helper()
from nemo.collections.nlp.data.language_modeling.megatron import helpers
except:
raise Exception('Could not compile megatron dataset C++ helpers.')
helpers.build_blending_indices(
self.dataset_index,
self.dataset_sample_index,
weights,
num_datasets,
self.size,
torch.distributed.get_rank() == 0,
)
logging.info(
'> elapsed time for building blendable dataset indices: ' '{:.2f} (sec)'.format(time.time() - start_time)
)
def __len__(self):
return self.size
def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
return self.datasets[dataset_idx][sample_idx]


@ -0,0 +1,121 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloaders."""
import torch
from nemo.utils import logging
class MegatronPretrainingSampler:
def __init__(
self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
logging.info(
f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} and consumed_samples: {consumed_samples}'
)
# Sanity checks.
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, 'no samples left to consume: {}, {}'.format(
self.consumed_samples, self.total_samples
)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, (
'data_parallel_rank should be smaller than data size: {}, '
'{}'.format(self.data_parallel_rank, data_parallel_size)
)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
# The last batch will be dropped unless drop_last is set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, (
'data_parallel_rank should be smaller than data size: {}, '
'{}'.format(self.data_parallel_rank, data_parallel_size)
)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
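A minimal sketch of how the samplers above are meant to be consumed: they yield lists of dataset indices for one data-parallel rank, so they plug into a DataLoader via batch_sampler. The import path is an assumption based on the package layout in this PR:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Assumed module path for the samplers defined above.
    from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler

    dataset = TensorDataset(torch.arange(32))
    sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=4,
        data_parallel_rank=0,
        data_parallel_size=1,
    )
    loader = DataLoader(dataset, batch_sampler=sampler)
    for batch in loader:
        print(batch)  # four samples per micro-batch for this data-parallel rank
        break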


@ -0,0 +1,732 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import collections
import math
import os
import subprocess
import time
import numpy as np
import torch
from apex.transformer import parallel_state
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
)
return prefixes, weights, datasets_train_valid_test_num_samples
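# --- Illustrative note (not part of this module): with
# data_prefix = [0.5, 'A_text_document', 0.5, 'B_text_document'] and
# train/valid/test targets of [1000, 100, 100], each dataset is asked for
# [503, 51, 51] samples, e.g. ceil(1000 * 0.5 * 1.005) = 503 and
# ceil(100 * 0.5 * 1.005) = 51, so blending never runs short of samples.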
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(['make', '-C', path])
if ret.returncode != 0:
logging.error("Making C++ dataset helpers module failed, exiting.")
import sys
sys.exit(1)
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
# Number of sentences in the sample.
n_sentences = len(sample)
# Make sure we always have two sentences.
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
# First part:
# `a_end` is how many sentences go into the `A`.
a_end = 1
if n_sentences >= 3:
# Note that randint in numpy is exclusive.
a_end = np_rng.randint(1, n_sentences)
tokens_a = []
for j in range(a_end):
tokens_a.extend(sample[j])
# Second part:
tokens_b = []
for j in range(a_end, n_sentences):
tokens_b.extend(sample[j])
# Random next:
is_next_random = False
if np_rng.random() < 0.5:
is_next_random = True
tokens_a, tokens_b = tokens_b, tokens_a
return tokens_a, tokens_b, is_next_random
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
# print(len_a, len_b, max_num_tokens)
assert len_a > 0
if len_a + len_b <= max_num_tokens:
return False
while len_a + len_b > max_num_tokens:
if len_a > len_b:
len_a -= 1
tokens = tokens_a
else:
len_b -= 1
tokens = tokens_b
if np_rng.random() < 0.5:
del tokens[0]
else:
tokens.pop()
return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
tokens = []
tokentypes = []
# [CLS].
tokens.append(cls_id)
tokentypes.append(0)
# Segment A.
for token in tokens_a:
tokens.append(token)
tokentypes.append(0)
# [SEP].
tokens.append(sep_id)
tokentypes.append(0)
# Segment B.
for token in tokens_b:
tokens.append(token)
tokentypes.append(1)
if tokens_b:
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes
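# --- Illustrative note (not part of this module): with tokens_a = [7, 8],
# tokens_b = [9], cls_id = 101 and sep_id = 102 the result is
# tokens     = [101, 7, 8, 102, 9, 102]
# tokentypes = [0,   0, 0, 0,   1, 1]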
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
def is_start_piece(piece):
"""Check if the current word piece is the starting piece (BERT)."""
# When a word has been split into
# WordPieces, the first token does not have any marker and any subsequent
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
return not piece.startswith("##")
def create_masked_lm_predictions(
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
do_permutation=False,
geometric_dist=False,
masking_style="bert",
):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = []
# Note(mingdachen): We create a list for recording if the piece is
# the starting piece of current token, where 1 means true, so that
# on-the-fly whole word masking is possible.
token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens):
if token == cls_id or token == sep_id:
token_boundary[i] = 1
continue
# Whole Word Masking means that we mask all of the wordpieces
# corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
if is_start_piece(vocab_id_to_token_dict[token]):
token_boundary[i] = 1
output_tokens = list(tokens)
masked_lm_positions = []
masked_lm_labels = []
if masked_lm_prob == 0:
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
if not geometric_dist:
# Note(mingdachen):
# By default, we set the probabilities to favor shorter ngram sequences.
pvals = 1.0 / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
if favor_longer_ngram:
pvals = pvals[::-1]
ngram_indexes = []
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx : idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
(masked_lms, masked_spans) = ([], [])
covered_indexes = set()
for cand_index_set in ngram_indexes:
if len(masked_lms) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes:
continue
if not geometric_dist:
n = np_rng.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
)
else:
# Sampling "n" from the geometric distribution and clipping it to
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
n = min(np_rng.geometric(0.2), max_ngrams)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# Note(mingdachen):
# Repeatedly looking for a candidate that does not exceed the
# maximum number of predictions by trying shorter ngrams.
while len(masked_lms) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
if masking_style == "bert":
# 80% of the time, replace with [MASK]
if np_rng.random() < 0.8:
masked_token = mask_id
else:
# 10% of the time, keep original
if np_rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
elif masking_style == "t5":
masked_token = mask_id
else:
raise ValueError("invalid value of masking style")
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_spans.append(MaskedLmInstance(index=index_set, label=[tokens[index] for index in index_set]))
assert len(masked_lms) <= num_to_predict
np_rng.shuffle(ngram_indexes)
select_indexes = set()
if do_permutation:
for cand_index_set in ngram_indexes:
if len(select_indexes) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes or index in select_indexes:
continue
n = np.random.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
while len(select_indexes) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(select_indexes) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes or index in select_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
select_indexes.add(index)
assert len(select_indexes) <= num_to_predict
select_indexes = sorted(select_indexes)
permute_indexes = list(select_indexes)
np_rng.shuffle(permute_indexes)
orig_token = list(output_tokens)
for src_i, tgt_i in zip(select_indexes, permute_indexes):
output_tokens[src_i] = orig_token[tgt_i]
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
# Sort the spans by the index of the first span
masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
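The span lengths above are drawn either from a categorical distribution that favors shorter ngrams (the `pvals` array) or, following SpanBERT, from a geometric distribution clipped to `max_ngrams`. A minimal sketch of just those two sampling rules, assuming only NumPy (not part of the diff):

```python
import numpy as np

max_ngrams = 3
np_rng = np.random.RandomState(1234)

# Categorical favoring shorter spans: p(n) proportional to 1/n, as built above.
pvals = 1.0 / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
n_categorical = np_rng.choice(np.arange(1, max_ngrams + 1), p=pvals)

# Geometric alternative (SpanBERT, p=0.2), clipped to max_ngrams.
n_geometric = min(np_rng.geometric(0.2), max_ngrams)

print(n_categorical, n_geometric)
```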
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
# Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
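A toy call to `pad_and_convert_to_numpy` (hypothetical token ids, assuming the function above is in scope) showing how the labels, loss mask, and padding mask line up:

```python
# Two masked positions inside a 4-token sequence padded to length 6.
tokens = [101, 7, 8, 102]            # hypothetical ids
tokentypes = [0, 0, 0, 0]
masked_positions = [1, 2]
masked_labels = [7, 8]

tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
    tokens, tokentypes, masked_positions, masked_labels, pad_id=0, max_seq_length=6
)
# tokens_np       -> [101, 7, 8, 102, 0, 0]
# padding_mask_np -> [1, 1, 1, 1, 0, 0]
# labels_np       -> [-1, 7, 8, -1, -1, -1]
# loss_mask_np    -> [0, 1, 1, 0, 0, 0]
```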
def build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head=False,
max_seq_length_dec=None,
dataset_type='standard_bert',
):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(
data_prefix[0],
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
max_seq_length_dec,
dataset_type=dataset_type,
)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i],
data_impl,
splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type=dataset_type,
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
def _build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
max_seq_length_dec,
dataset_type='standard_bert',
):
if dataset_type not in DSET_TYPES:
raise ValueError(f"Invalid dataset_type: {dataset_type}")
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup)
# Get start and end indices of train/valid/test into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
logging.info(' > dataset split:')
def print_split_stats(name, index):
logging.info(' {}:'.format(name))
logging.info(
' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
)
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
logging.info(
' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index, end_index - start_index)
)
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset
from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
seed=seed,
)
if dataset_type == DSET_TYPE_T5:
dataset = T5Dataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
max_seq_length_dec=max_seq_length_dec,
short_seq_prob=short_seq_prob,
**kwargs,
)
elif dataset_type == DSET_TYPE_BERT:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
binary_head=binary_head,
**kwargs,
)
else:
raise NotImplementedError("Dataset type not fully implemented.")
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
logging.info(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
logging.info(' > finished creating indexed dataset in {:4f} ' 'seconds'.format(time.time() - start_time))
logging.info(' > indexed dataset stats:')
logging.info(' number of documents: {}'.format(indexed_dataset.doc_idx.shape[0] - 1))
logging.info(' number of sentences: {}'.format(indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
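A couple of worked examples of `get_train_valid_test_split_` (assuming the function above is in scope): the weights are normalized, turned into cumulative document boundaries, and padded to three splits if fewer are given:

```python
print(get_train_valid_test_split_('90,5,5', 100))     # [0, 90, 95, 100]
print(get_train_valid_test_split_('0.99,0.01', 100))  # [0, 99, 100, 100] (missing split -> 0.0)
```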
def get_samples_mapping(
indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
):
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples " "or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
print(
' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename)
)
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
logging.info(' > building samples index mapping for {} ...'.format(name))
# First compile and then import.
try:
if is_global_rank_zero():
compile_helper()
from nemo.collections.nlp.data.language_modeling.megatron import helpers
except:
raise Exception('Could not compile helpers.')
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
verbose,
2 if binary_head else 1,
)
logging.info(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
logging.info(' > saved the index mapping in {}'.format(indexmap_filename))
# Make sure all the ranks have built the mapping
logging.info(
' > elapsed time to build and save samples mapping ' '(seconds): {:4f}'.format(time.time() - start_time)
)
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
)
# Load indexed dataset.
logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
logging.info(' total number of samples: {}'.format(samples_mapping.shape[0]))
return samples_mapping


@@ -0,0 +1,449 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT style dataset."""
import os
import time
import numpy as np
import torch
from apex.transformer import parallel_state
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
get_datasets_weights_and_num_samples,
get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_dataset import MegatronDataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
def build_train_valid_test_datasets(
cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
):
"""Build train, valid, and test datasets."""
# Single dataset.
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(
cfg,
trainer,
data_prefix[0],
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
skip_warmup,
)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
cfg,
trainer,
prefixes[i],
data_impl,
splits_string,
datasets_train_valid_test_num_samples[i],
seq_length,
seed,
skip_warmup,
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
def _build_train_valid_test_datasets(
cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
logging.info(' > dataset split:')
def print_split_stats(name, index):
logging.info(' {}:'.format(name))
logging.info(
' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
)
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
dataset = GPTDataset(
cfg,
trainer,
name,
data_prefix,
documents,
indexed_dataset,
train_valid_test_num_samples[index],
seq_length,
seed,
)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
"""Build indexed dataset."""
logging.info(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
logging.info(' > finished creating indexed dataset in {:4f} ' 'seconds'.format(time.time() - start_time))
logging.info(' number of documents: {}'.format(indexed_dataset.sizes.shape[0]))
return indexed_dataset
class GPTDataset(MegatronDataset):
def __init__(self, cfg, trainer, name, data_prefix, documents, indexed_dataset, num_samples, seq_length, seed):
super().__init__(cfg, trainer=trainer)
self.name = name
self.indexed_dataset = indexed_dataset
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name, data_prefix, documents, self.indexed_dataset.sizes, num_samples, seq_length, seed
)
def __len__(self):
# -1 is due to the data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
return self.sample_idx.shape[0] - 1
def __getitem__(self, idx):
# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
doc_index_f = self.sample_idx[idx][0]
doc_index_l = self.sample_idx[idx + 1][0]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# If we are within the same document, just extract the chunk.
if doc_index_f == doc_index_l:
sample = self.indexed_dataset.get(
self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1
)
else:
# Otherwise, get the rest of the initial document.
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1))
sample = np.concatenate(sample_list)
return {'text': np.array(sample, dtype=np.int64)}
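The chunking above always returns `seq_length + 1` contiguous tokens from the flattened document stream, stitching across document boundaries when needed; the last token of one sample overlaps the first token of the next. A toy re-implementation over plain Python lists (hypothetical token ids, illustrative only):

```python
docs = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]   # hypothetical documents
flat = [tok for doc in docs for tok in doc]
seq_length = 4

def get_sample(i):
    # Sample i covers seq_length + 1 tokens starting at i * seq_length.
    start = i * seq_length
    return flat[start : start + seq_length + 1]

print(get_sample(0))  # [10, 11, 12, 20, 21]
print(get_sample(1))  # [21, 30, 31, 32, 33]  (token 21 overlaps sample 0)
```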
def _build_index_mappings(name, data_prefix, documents, sizes, num_samples, seq_length, seed):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
sample-idx: is the start document index and document offset for each
training sample.
shuffle-idx: maps the sample index into a random index into sample-idx.
"""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}sl'.format(seq_length)
_filename += '_{}s'.format(seed)
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (
(not os.path.isfile(doc_idx_filename))
or (not os.path.isfile(sample_idx_filename))
or (not os.path.isfile(shuffle_idx_filename))
):
logging.info(' > WARNING: could not find index map files, building ' 'the indices on rank 0 ...')
# For the last epoch, decide whether to include the entire epoch
# in the global shuffle or not.
# If we need only one epoch, then separating out the last epoch
# does not mean anything.
if num_epochs == 1:
separate_last_epoch = False
print(' > only one epoch required, setting ' 'separate_last_epoch to False', flush=True)
else:
# Get the number of samples for the last epoch
num_samples_from_epochs_minus_one = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length
last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one
assert last_epoch_num_samples >= 0, 'last epoch number of samples should be non-negative.'
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples < (
num_samples_per_epoch + 1
), 'last epoch number of samples exceeded max value.'
# If we have fewer than 80% of the samples for the last epoch,
# separate out the epoch and treat it differently.
# Note: the 80% number is just based on common sense and can
# be adjusted if needed.
separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch)
if separate_last_epoch:
string = (
' > last epoch number of samples ({}) is smaller '
'than 80% of number of samples per epoch ({}), '
'setting separate_last_epoch to True'
)
else:
string = (
' > last epoch number of samples ({}) is larger '
'than 80% of number of samples per epoch ({}), '
'setting separate_last_epoch to False'
)
print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True)
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
logging.info(
' > elapsed time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time)
)
# sample-idx.
start_time = time.time()
# Use C++ implementation for speed.
# First compile and then import.
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
try:
if is_global_rank_zero():
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper
compile_helper()
from nemo.collections.nlp.data.language_modeling.megatron import helpers
except:
raise Exception('Could not compile helpers.')
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch)
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
logging.info(
' > elapsed time to build and save sample-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time)
)
# shuffle-idx.
start_time = time.time()
# -1 is due to the data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
if separate_last_epoch:
num_samples_ = num_samples_from_epochs_minus_one
else:
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
logging.info(
' > elapsed time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time)
)
torch.distributed.barrier()
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
)
# Load mappings.
start_time = time.time()
logging.info(' > loading doc-idx mapping from {}'.format(doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
logging.info(' > loading sample-idx mapping from {}'.format(sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
logging.info(' > loading shuffle-idx mapping from {}'.format(shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
logging.info(' total number of samples: {}'.format(sample_idx.shape[0]))
logging.info(' total number of epochs: {}'.format(num_epochs))
return doc_idx, sample_idx, shuffle_idx
def _num_tokens(documents, sizes):
"""Total number of tokens in the dataset."""
return np.sum(sizes[documents])
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
"""Based on the number of samples and sequence length, calculate how many
epochs will be needed."""
num_epochs = 0
total_tokens = 0
while True:
num_epochs += 1
total_tokens += tokens_per_epoch
# -1 is because we need to retrieve seq_length + 1 token each time
# but the last token will overlap with the first token of the next
# sample except for the last sample.
if ((total_tokens - 1) // seq_length) >= num_samples:
return num_epochs
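A quick sanity check of `_num_epochs` (assuming the function above is in scope): with 100 tokens per epoch and `seq_length=10`, each epoch yields roughly 10 samples, so 25 samples need 3 epochs since `(3 * 100 - 1) // 10 = 29 >= 25`:

```python
print(_num_epochs(tokens_per_epoch=100, seq_length=10, num_samples=25))  # 3
```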
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-documents.
Each index is mapped to a corresponding document."""
if not separate_last_epoch or num_epochs == 1:
doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
"""Sample index mapping is a 2D array with sizes
[number-of-samples + 1, 2] where [..., 0] contains
the index into `doc_idx` and [..., 1] is the
starting offset in that document."""
# Total number of samples. For -1 see comments in `_num_epochs`.
num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
# Index into sample_idx.
sample_index = 0
# Index into doc_idx.
doc_idx_index = 0
# Beginning offset for each document.
doc_offset = 0
# Start with first document and no offset.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
while sample_index <= num_samples:
# Start with a fresh sequence.
remaining_seq_length = seq_length + 1
while remaining_seq_length != 0:
# Get the document length.
doc_id = doc_idx[doc_idx_index]
doc_length = sizes[doc_id] - doc_offset
# And add it to the current sequence.
remaining_seq_length -= doc_length
# If we have more than a full sequence, adjust offset and set
# remaining length to zero so we return from the while loop.
# Note that -1 here is for the same reason we have -1 in
# `_num_epochs` calculations.
if remaining_seq_length <= 0:
doc_offset += remaining_seq_length + doc_length - 1
remaining_seq_length = 0
else:
# Otherwise, start from the beginning of the next document.
doc_idx_index += 1
doc_offset = 0
# Record the sequence.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
return sample_idx
def _build_shuffle_idx(num_samples, total_size, np_rng):
"""Build the range [0, size) and shuffle."""
print(
' > building shuffle index with split [0, {}) and [{}, {}) '
'...'.format(num_samples, num_samples, total_size),
flush=True,
)
dtype_ = np.uint32
if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
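A small demonstration of `_build_shuffle_idx` (assuming the function above is in scope): indices below `num_samples` and indices in `[num_samples, total_size)` are shuffled independently, which is what keeps the separated last epoch in the tail of the shuffle index:

```python
import numpy as np

rng = np.random.RandomState(1234)
idx = _build_shuffle_idx(num_samples=4, total_size=6, np_rng=rng)
# The first 4 entries are a permutation of 0..3, the last 2 a permutation of 4..5.
```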


@@ -0,0 +1,37 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import torch
from torch.utils.data.dataset import Dataset
class GPTRequestDataset(Dataset):
def __init__(self, request: Dict, tokenizer) -> None:
super().__init__()
self.request = request
self.tokenizer = tokenizer
# tokenize prompt
self.request['tokenized_prompt'] = self.tokenizer.text_to_tokens(request['prompt'])
tokens = self.tokenizer.text_to_ids(request['prompt'])
self.request['tokens'] = torch.tensor(tokens)
def __len__(self):
return 1
def __getitem__(self, index):
return self.request
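A minimal usage sketch of `GPTRequestDataset` with a hypothetical toy tokenizer exposing the two methods called above (illustrative only, not the real NeMo tokenizer):

```python
class ToyTokenizer:
    def text_to_tokens(self, text):
        return text.split()

    def text_to_ids(self, text):
        # Stand-in for real subword ids.
        return [len(tok) for tok in text.split()]

request = {'prompt': 'hello megatron world'}
dataset = GPTRequestDataset(request, ToyTokenizer())

item = dataset[0]
# item['tokenized_prompt'] -> ['hello', 'megatron', 'world']
# item['tokens']           -> tensor([5, 8, 5])
```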


@@ -0,0 +1,716 @@
/*
Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/* Helper methods for fast index mapping builds */
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
namespace py = pybind11;
using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
py::array_t<int64_t>& dataset_sample_index,
const py::array_t<double>& weights,
const int32_t num_datasets,
const int64_t size, const bool verbose) {
/* Given multiple datasets and a weighting array, build samples
such that they follow those weights.*/
if (verbose) {
std::cout << "> building indices for blendable datasets ..." << std::endl;
}
// Get the pointer access without the checks.
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto weights_ptr = weights.unchecked<1>();
// Initialize buffer for number of samples used for each dataset.
int64_t current_samples[num_datasets];
for(int64_t i = 0; i < num_datasets; ++i) {
current_samples[i] = 0;
}
// For each sample:
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
// Determine where the max error in sampling is happening.
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
int64_t max_error_index = 0;
double max_error = weights_ptr[0] * sample_idx_double -
static_cast<double>(current_samples[0]);
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
double error = weights_ptr[dataset_idx] * sample_idx_double -
static_cast<double>(current_samples[dataset_idx]);
if (error > max_error) {
max_error = error;
max_error_index = dataset_idx;
}
}
// Populate the indices.
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
// Update the total samples.
current_samples[max_error_index] += 1;
}
// print info
if (verbose) {
std::cout << " > sample ratios:" << std::endl;
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
static_cast<double>(size);
std::cout << " dataset " << dataset_idx << ", input: " <<
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
}
}
}
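The blending policy above is greedy: for every global sample it picks the dataset whose achieved count lags its target weight the most. A pure-Python sketch of the same policy (illustrative only, not the compiled helper):

```python
def blend(weights, size):
    current = [0] * len(weights)
    dataset_index, dataset_sample_index = [], []
    for sample_idx in range(size):
        denom = max(float(sample_idx), 1.0)
        # Error = how far each dataset is behind its target share.
        errors = [w * denom - c for w, c in zip(weights, current)]
        pick = max(range(len(weights)), key=lambda i: errors[i])
        dataset_index.append(pick)
        dataset_sample_index.append(current[pick])
        current[pick] += 1
    return dataset_index, dataset_sample_index

print(blend([0.7, 0.3], 10))  # dataset 0 is chosen ~7 times, dataset 1 ~3 times
```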
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch) {
/* Sample index (sample_idx) is used for GPT-2-style datasets for which
the documents are flattened and the samples are built from this
1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
// Mapping and its length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Beginning offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the beginning of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_) {
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(int32_t);
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
{2*byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen) {
/* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1);
}
return max_length;
}
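`get_target_sample_len` converts `short_seq_prob` into an integer ratio and reuses a single 32-bit draw for both the probability check and the length. A simplified Python analogue of the same idea (separate draws, illustrative only):

```python
import random

def get_target_sample_len_py(short_seq_prob, max_length, rng):
    # With probability ~short_seq_prob pick a random length in [2, max_length],
    # otherwise use the full max_length.
    if short_seq_prob > 0 and rng.random() < short_seq_prob:
        return rng.randint(2, max_length)
    return max_length

rng = random.Random(1234)
print([get_target_sample_len_py(0.1, 512, rng) for _ in range(5)])
```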
template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const double short_seq_prob,
const int32_t seed,
const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int.
int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " short sequence probability: " << short_seq_prob <<
endl << std::flush;
cout << "      short sequence ratio (1/prob): " << short_seq_ratio <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the seed so both iterations produce the same results.
std::mt19937 rand32_gen(seed);
// Set the flag on second iteration.
second = (iteration == 1);
// Counters:
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// Current map index.
uint64_t map_index = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
// At the beginning of the document the previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent > 1) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have at least min_num_sent sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
auto target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length,
// and more than one sentence is left in the document,
// and we have at least the minimum number of sentences;
// or if we have reached the end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow.
if ((3 * map_index + 2) >
std::numeric_limits<int64_t>::max()) {
cout << "number of samples exceeded maximum "
<< "allowed by type int64: "
<< std::numeric_limits<int64_t>::max()
<< endl;
throw std::overflow_error("Number of samples");
}
// Populate the map.
if (second) {
const auto map_index_0 = 3 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
}
// Update indices / counters.
++map_index;
prev_start_index = sent_index + 1;
target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const double short_seq_prob,
const int seed,
const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
}
}
template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& titles_sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
auto titles_sizes = titles_sizes_.unchecked<1>();
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the flag on second iteration.
second = (iteration == 1);
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
const auto target_seq_len = max_seq_length - titles_sizes[doc];
// At the beginning of the document the previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length,
// and an acceptable number of sentences are left,
// and we have at least the minimum number of sentences;
// or if we have reached the end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map.
if (second) {
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const py::array_t<int>& titles_sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_blending_indices", &build_blending_indices);
}
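These bindings are compiled on the fly and imported on rank 0, following the pattern used by the Python files earlier in this diff; a sketch of that flow (illustrative only):

```python
# Compile the C++ helpers with pybind11 (if not already built), then import them.
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper

compile_helper()
from nemo.collections.nlp.data.language_modeling.megatron import helpers  # noqa: E402

# e.g. helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch)
# returns the [num_samples + 1, 2] int32 sample-index array described above.
```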


@@ -0,0 +1,564 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Most of the code here has been copied from:
# fairseq/fairseq/data/indexed_dataset.py
# with some modifications:
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
import numpy as np
import torch
from nemo.utils import logging
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
else:
return np.int32
def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
else:
return None
else:
print(f"Dataset does not exist: {path}")
print("Path should be a basename to which both .idx and .bin can be appended to get the full filenames.")
return None
def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename to which both .idx and .bin can be appended to get the full filenames.")
return None
if impl == 'infer':
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == 'mmap':
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i + 1)
return doc_idx
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
def __init__(self, path):
super().__init__()
self.path = path
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
def __del__(self):
if self.data_file:
self.data_file.close()
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return a
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
self.data_file.readinto(a)
offsets = list(accumulate(sizes))
sents = np.split(a, offsets[:-1])
return sents
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx : ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx : ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work; can optimize later if necessary
sents = []
for i in range(*idx.indices(len(self))):
sents.append(self[i])
return sents
class IndexedDatasetBuilder(object):
element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float: 4, np.double: 8}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
self.doc_idx = [0]
def add_item(self, tensor):
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def end_document(self):
self.doc_idx.append(len(self.sizes))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack('<Q', len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
write_longs(index, self.doc_idx)
index.close()
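As a point of reference, here is a minimal sketch of how the builder and reader above are typically paired, assuming the index_file_path/data_file_path helpers defined earlier in this file; the file prefix and token ids are hypothetical.

import numpy as np
import torch

prefix = '/tmp/my_corpus_text_document'  # hypothetical prefix; the helpers map it to .bin/.idx names
builder = IndexedDatasetBuilder(data_file_path(prefix), dtype=np.int32)
for ids in ([10, 11, 12], [13, 14]):     # toy token-id sequences
    builder.add_item(torch.tensor(ids))
    builder.end_document()
builder.finalize(index_file_path(prefix))

ds = IndexedDataset(prefix)              # the .bin file is opened lazily on first access
print(len(ds), ds[0])                    # 2, array([10, 11, 12], dtype=int32)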
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
self._file.write(struct.pack('<Q', len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
(dtype_code,) = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
logging.info(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
logging.info(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
logging.info(" reading pointers...")
self._pointers = np.frombuffer(
self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
)
logging.info(" reading document index...")
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
# def __setstate__(self, state):
# self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
logging.info(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
logging.info(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
logging.info(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
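And the equivalent sketch for the mmap variant, again with a hypothetical prefix and toy ids; get() reads a slice of one sample directly from the memory map without materializing the whole item.

import numpy as np
import torch

prefix = '/tmp/my_corpus_text_document'
builder = MMapIndexedDatasetBuilder(data_file_path(prefix), dtype=np.int64)
builder.add_item(torch.tensor([10, 11, 12]))
builder.end_document()
builder.add_item(torch.tensor([13, 14]))
builder.end_document()
builder.finalize(index_file_path(prefix))

ds = MMapIndexedDataset(prefix, skip_warmup=True)
full = ds[0]                          # array([10, 11, 12])
part = ds.get(0, offset=1, length=2)  # array([11, 12])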


@@ -0,0 +1,48 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron dataset class that handles initialization and other megatron specific things common to all datasets."""
import torch
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo
from nemo.utils import AppState, logging
class MegatronDataset(torch.utils.data.Dataset):
"""
Base Megatron dataset that ensures model parallel state is initialized before data is served.
"""
def __init__(self, cfg: DictConfig, trainer: Trainer):
app_state = AppState()
if not app_state._is_megatron_initialized:
logging.info(
f"Initializing megatron since it hasn't been initialized by the model. This is normal if you are using a NeMo model with Megatron dataloaders."
)
app_state.global_rank = trainer.global_rank
app_state.world_size = trainer.world_size
app_state.model_parallel_size = 1
app_state.model_parallel_rank = trainer.global_rank
initialize_model_parallel_for_nemo(
world_size=trainer.world_size,
global_rank=trainer.global_rank,
local_rank=trainer.local_rank,
tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
seed=cfg.get('seed', 1234),
)
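A minimal sketch of a subclass relying on this base class; the class name and toy data are hypothetical, and the point is simply that super().__init__ triggers the lazy Megatron initialization above before any data is served.

class ToyMegatronDataset(MegatronDataset):
    def __init__(self, cfg: DictConfig, trainer: Trainer, token_ids):
        super().__init__(cfg, trainer)  # initializes Megatron parallel state if needed
        self.token_ids = token_ids

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        return torch.tensor(self.token_ids[idx], dtype=torch.long)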


@@ -0,0 +1,322 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 Style dataset."""
import collections
import numpy as np
import torch
from megatron import get_tokenizer
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
create_masked_lm_predictions,
get_samples_mapping,
)
class T5Dataset(torch.utils.data.Dataset):
def __init__(
self,
name,
indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
masked_lm_prob,
max_seq_length,
max_seq_length_dec,
short_seq_prob,
seed,
):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.max_seq_length_dec = max_seq_length_dec
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(
self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 2, # account for added tokens
short_seq_prob,
self.seed,
self.name,
False,
)
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
self.bos_id = tokenizer.bos_token_id
self.eos_id = tokenizer.eos_token_id
self.sentinel_tokens = tokenizer.additional_special_tokens_ids
assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx]
sample = []
for index in range(start_index, end_index):
sample.append(self.indexed_dataset[index])
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(
sample,
seq_length,
self.max_seq_length, # needed for padding
self.max_seq_length_dec,
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng,
self.bos_id,
self.eos_id,
self.sentinel_tokens,
)
def build_training_sample(
sample,
target_seq_length,
max_seq_length,
max_seq_length_dec,
vocab_id_list,
vocab_id_to_token_dict,
cls_id,
sep_id,
mask_id,
pad_id,
masked_lm_prob,
np_rng,
bos_id=None,
eos_id=None,
sentinel_tokens=None,
):
"""Build training sample.
Arguments:
sample: A list of sentences, where each sentence is a list of token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
bos_id: start of decoder example id
eos_id: end of generation id
sentinel_tokens: unique value to be substituted for every replaced span
"""
assert target_seq_length <= max_seq_length
# flatten sentences into one list
tokens = [token for sentence in sample for token in sentence]
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = len(tokens) > max_num_tokens
tokens = tokens[:max_num_tokens]
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=10,
geometric_dist=True,
masking_style="t5",
)
# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask = pad_and_convert_to_numpy(
tokens,
masked_positions,
masked_labels,
pad_id,
max_seq_length,
max_seq_length_dec,
masked_spans,
bos_id,
eos_id,
sentinel_tokens,
)
train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
return train_sample
def pad_and_convert_to_numpy(
tokens,
masked_positions,
masked_labels,
pad_id,
max_seq_length,
max_seq_length_dec,
masked_spans=None,
bos_id=None,
eos_id=None,
sentinel_tokens=None,
):
"""Pad sequences and convert them to numpy."""
sentinel_tokens = collections.deque(sentinel_tokens)
t5_input = []
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
(start_index, end_index) = (0, None)
for span in masked_spans:
flag = sentinel_tokens.popleft()
# Append the same tokens in decoder input and output
t5_decoder_in.append(flag)
t5_decoder_in.extend(span.label)
t5_decoder_out.append(flag)
t5_decoder_out.extend(span.label)
end_index = span.index[0]
t5_input.extend(tokens[start_index:end_index])
t5_input.append(flag)
# the next start index is the token after the last span token
start_index = span.index[-1] + 1
# Add <eos> token to the t5_decoder_out
t5_decoder_out.append(eos_id)
# Add the remaining tokens to the t5 input
t5_input.extend(tokens[start_index:])
# assert (len(t5_input) - len(masked_spans)) + \
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
# Some checks.
# Encoder-side padding mask.
num_tokens = len(t5_input)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(masked_positions) == len(masked_labels)
# Tokens..
filler = [pad_id] * padding_length
tokens_enc = np.array(t5_input + filler, dtype=np.int64)
# Decoder-side padding mask.
num_tokens_dec = len(t5_decoder_in)
padding_length_dec = max_seq_length_dec - num_tokens_dec
assert padding_length_dec >= 0
filler_dec = [pad_id] * padding_length_dec
tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
# Create attention masks
enc_mask = make_attention_mask(tokens_enc, tokens_enc)
enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
dec_mask = dec_mask * make_history_mask(tokens_dec_in)
# Labels mask.
labels = t5_decoder_out + ([-1] * padding_length_dec)
labels = np.array(labels, dtype=np.int64)
# Loss mask
loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
loss_mask = np.array(loss_mask, dtype=np.int64)
return tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask
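To make the sentinel bookkeeping above concrete, here is a hand-worked toy example (all ids are made up): one masked span covering the original tokens [12, 13] inside [10, 11, 12, 13, 14], with sentinel id 250, bos id 1 and eos id 2, yields the following sequences before padding.

tokens = [10, 11, 12, 13, 14]       # positions 2-3 form the masked span and are skipped below
span_label = [12, 13]               # original tokens of the masked span
sentinel, bos_id, eos_id = 250, 1, 2

t5_input = [10, 11, sentinel, 14]                    # span replaced by its sentinel
t5_decoder_in = [bos_id, sentinel] + span_label      # [1, 250, 12, 13]
t5_decoder_out = [sentinel] + span_label + [eos_id]  # [250, 12, 13, 2]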
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def make_attention_mask_3d(source_block, target_block):
"""
Returns a 3-dimensional (3-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
# (batch, source_length, target_length)
# mask = mask.astype(np.int64)
return mask
def make_history_mask(block):
length = block.shape[0]
arange = np.arange(length)
history_mask = arange[None,] <= arange[:, None]
history_mask = history_mask.astype(np.int64)
return history_mask
def make_history_mask_3d(block):
batch, length = block.shape
arange = torch.arange(length, device=block.device)
history_mask = (arange[None,] <= arange[:, None])[
None,
]
history_mask = history_mask.expand(batch, length, length)
return history_mask
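A quick sanity check of the 2-D helpers above on toy arrays (values are arbitrary; ids below 1 are treated as padding because of the >= 1 test):

import numpy as np

source = np.array([7, 8, 0])   # last position is padding
target = np.array([9, 0])
make_attention_mask(source, target)
# array([[1, 0],
#        [1, 0],
#        [0, 0]])  shape (source_length, target_length)
make_history_mask(np.array([7, 8, 9]))
# array([[1, 0, 0],
#        [1, 1, 0],
#        [1, 1, 1]])  causal lower-triangular mask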


@@ -0,0 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from nemo.collections.nlp.models.language_modeling.megatron.bert_model import BertModel
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
# from nemo.collections.nlp.models.language_modeling.megatron.t5_model import T5Model


@@ -0,0 +1,221 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""
import torch
from apex.transformer import parallel_state, tensor_parallel
from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import (
erf_gelu,
get_linear_layer,
init_method_normal,
openai_gelu,
scaled_init_method_normal,
)
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = extended_attention_mask < 0.5
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
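A small shape check for the two helpers above, using a hypothetical batch of one sequence of length three whose last token is padding:

import torch

token_ids = torch.tensor([[101, 2054, 0]])          # made-up ids, 0 = pad
attention_mask = torch.tensor([[1, 1, 0]])
ext = bert_extended_attention_mask(attention_mask)  # shape [1, 1, 3, 3]; True where attention is blocked
position_ids = bert_position_ids(token_ids)         # tensor([[0, 1, 2]])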
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits remain distributed across model parallel ranks.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method, layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
parallel_state.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias)
return output
def post_language_model_processing(
lm_output, pooled_output, lm_head, binary_head, lm_labels, logit_weights, fp16_lm_cross_entropy
):
# Output.
lm_logits = lm_head(lm_output, logit_weights)
binary_logits = None
if binary_head is not None:
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
else:
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
return lm_loss, binary_logits
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(
self, num_tokentypes=2, add_binary_head=True, parallel_output=True, pre_process=True, post_process=True
):
super(BertModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process,
)
self.initialize_word_embeddings(init_method_normal)
if self.post_process:
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
args.hidden_size,
init_method,
args.layernorm_epsilon,
parallel_output,
)
self._lm_head_key = 'lm_head'
self.binary_head = None
if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2, init_method)
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask, tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(input_ids, position_ids, extended_attention_mask, tokentype_ids=tokentype_ids)
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if self.post_process:
return post_language_model_processing(
lm_output,
pooled_output,
self.lm_head,
self.binary_head,
lm_labels,
self.word_embeddings_weight(),
self.fp16_lm_cross_entropy,
)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.post_process:
state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.post_process and self.add_binary_head:
state_dict_[self._binary_head_key] = self.binary_head.state_dict(destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict)
if self.post_process:
self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict)
if self.post_process and self.add_binary_head:
self.binary_head.load_state_dict(state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)


@@ -0,0 +1,191 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import torch
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal, scaled_init_method_normal
def post_language_model_processing(
lm_output,
labels,
logit_weights,
get_key_value,
parallel_output,
forward_method_parallel_output,
fp16_lm_cross_entropy,
):
if get_key_value:
lm_output, presents = lm_output
# Output.
if forward_method_parallel_output is not None:
parallel_output = forward_method_parallel_output
output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
if get_key_value:
output = [output, presents]
if labels is None:
return output
else:
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
class GPTModel(MegatronModule):
"""GPT-2 Language model."""
def __init__(
self,
vocab_size,
hidden_size,
max_position_embeddings,
num_layers,
num_attention_heads,
ffn_hidden_size,
apply_query_key_layer_scaling=True,
kv_channels=None,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True,
init_method_std=0.02,
fp16_lm_cross_entropy=False,
use_cpu_initialization=False,
hidden_dropout=0.1,
fused_fp16=False,
fused_bf16=False,
fp32_residual_connection=False,
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
):
super(GPTModel, self).__init__()
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."
if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
self.language_model, self._language_model_key = get_language_model(
vocab_size=vocab_size,
hidden_size=hidden_size,
hidden_dropout=hidden_dropout,
num_tokentypes=num_tokentypes,
max_position_embeddings=max_position_embeddings,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
ffn_hidden_size=ffn_hidden_size,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(init_method_std),
scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
pre_process=self.pre_process,
post_process=self.post_process,
init_method_std=init_method_std,
use_cpu_initialization=use_cpu_initialization,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
)
self.initialize_word_embeddings(
init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(
self,
input_ids,
position_ids,
attention_mask,
labels=None,
tokentype_ids=None,
layer_past=None,
get_key_value=False,
forward_method_parallel_output=None,
):
lm_output = self.language_model(
input_ids, position_ids, attention_mask, layer_past=layer_past, get_key_value=get_key_value
)
if self.post_process:
return post_language_model_processing(
lm_output,
labels,
self.word_embeddings_weight(),
get_key_value,
self.parallel_output,
forward_method_parallel_output,
self.fp16_lm_cross_entropy,
)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
state_dict_ = {}
state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
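For orientation, a heavily abbreviated sketch of constructing this class directly; in practice MegatronGPTModel below builds it from its config. It assumes Apex is installed and model parallel state has already been initialized, and all sizes are toy values.

model = GPTModel(
    vocab_size=50304,
    hidden_size=768,
    max_position_embeddings=1024,
    num_layers=12,
    num_attention_heads=12,
    ffn_hidden_size=3072,
    use_cpu_initialization=True,  # build weights on CPU so no GPU is needed for the sketch
)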


@@ -0,0 +1,163 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 model."""
import torch
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from megatron import get_args
from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.transformer import LayerNorm
from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal, scaled_init_method_normal
def t5_extended_attention_mask(attention_mask_list):
def attn_mask_postprocess(attn_mask):
# [b, 1, s, s]
extended_attention_mask = attn_mask.unsqueeze(1)
return extended_attention_mask
return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]
def t5_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class T5LMHead(MegatronModule):
"""Masked LM head for T5
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits remain distributed across model parallel ranks.
"""
def __init__(self, mpu_vocab_size, parallel_output):
super(T5LMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = 1
self.parallel_output = parallel_output
def forward(self, hidden_states, word_embeddings_weight):
output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias)
return output
class T5Model(MegatronModule):
"""T5 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True):
super(T5Model, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
add_decoder=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
)
self.lm_head = T5LMHead(self.language_model.embedding.word_embeddings.weight.size(0), parallel_output)
self._lm_head_key = 'lm_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
tokentype_ids=None,
lm_labels=None,
enc_hidden_states=None,
):
# Converting the attention masks to proper parameter settings
encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
[encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
)
encoder_position_ids = t5_position_ids(encoder_input_ids)
decoder_position_ids = t5_position_ids(decoder_input_ids)
lm_output = self.language_model(
encoder_input_ids,
encoder_position_ids,
encoder_attn_mask,
decoder_input_ids,
decoder_position_ids,
decoder_attn_mask,
encoder_decoder_attn_mask,
tokentype_ids=tokentype_ids,
enc_hidden_states=enc_hidden_states,
)
decoder_output, encoder_output = lm_output
# Output.
lm_logits = self.lm_head(decoder_output, self.language_model.embedding.word_embeddings.weight)
if lm_labels is None:
return lm_logits, encoder_output
else:
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
return lm_loss, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict)


@@ -0,0 +1,406 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, Dict, Optional
import torch
from apex.transformer import parallel_state, tensor_parallel
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
MegatronPretrainingRandomSampler,
MegatronPretrainingSampler,
)
from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
from nemo.collections.nlp.models.nlp_model import NLPModel
from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32
from nemo.collections.nlp.modules.common.megatron.megatron_init import (
initialize_model_parallel_for_nemo,
set_jit_fusion_options,
)
from nemo.collections.nlp.modules.common.megatron.utils import (
average_losses_across_data_parallel_group,
get_ltor_masks_and_position_ids,
)
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.nlp.parts.nlp_overrides import NLPNativeMixedPrecisionPlugin
from nemo.utils import AppState, logging
class MegatronGPTModel(NLPModel):
"""
Megatron GPT pretraining
"""
def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
self.cfg = cfg
if self.cfg.get('use_cpu_initialization', False) is False:
torch.cuda.set_device(trainer.local_rank)
# buffer used during train_step for logging average loss over gradient accumulation steps
self._reduced_loss_buffer = []
initialize_model_parallel_for_nemo(
world_size=trainer.world_size,
global_rank=trainer.global_rank,
local_rank=trainer.local_rank,
tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
seed=self.cfg.get('seed', 1234),
)
if not self.cfg.get('fused_bf16'):
set_jit_fusion_options()
self.tokenizer = get_nmt_tokenizer(
library=self.cfg.tokenizer.library,
model_name=self.cfg.tokenizer.type,
tokenizer_model=self.register_artifact("tokenizer_model", self.cfg.tokenizer.model),
vocab_file=self.register_artifact("vocab_file", self.cfg.tokenizer.vocab_file),
merges_file=self.register_artifact("merges_file", self.cfg.tokenizer.merge_file),
)
vocab_size = self.tokenizer.vocab_size
padded_vocab_size = self._vocab_size_with_padding(
orig_vocab_size=vocab_size,
make_vocab_size_divisible_by=cfg.get('make_vocab_size_divisible_by', 128),
tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
)
self.model = GPTModel(
vocab_size=padded_vocab_size,
hidden_size=cfg.hidden_size,
max_position_embeddings=cfg.max_position_embeddings,
num_layers=cfg.num_layers,
num_attention_heads=cfg.num_attention_heads,
apply_query_key_layer_scaling=cfg.get('apply_query_key_layer_scaling', True),
kv_channels=cfg.get('kv_channels', None),
ffn_hidden_size=cfg.ffn_hidden_size,
num_tokentypes=0,
parallel_output=True,
pre_process=cfg.get('pre_process', True),
post_process=cfg.get('post_process', True),
init_method_std=cfg.get('init_method_std', 0.02),
fp16_lm_cross_entropy=cfg.get('fp16_lm_cross_entropy', False),
use_cpu_initialization=cfg.get('use_cpu_initialization', False),
hidden_dropout=cfg.get('hidden_dropout', 0.1),
fused_fp16=cfg.get('fused_fp16', False),
fused_bf16=cfg.get('fused_bf16', False),
fp32_residual_connection=cfg.get('fp32_residual_connection', False),
activations_checkpoint_method=cfg.get('activations_checkpoint_method', None),
activations_checkpoint_num_layers=cfg.get('activations_checkpoint_num_layers', 1),
layernorm_epsilon=cfg.get('layernorm_epsilon', 1e-5),
onnx_safe=cfg.get('onnx_safe', False),
)
def forward(self, tokens, position_ids, attention_mask, labels):
output_tensor = self.model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor
def training_step(self, batch, batch_idx):
tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
output_tensor = self(tokens, position_ids, attention_mask, labels)
loss = self.loss_func(loss_mask, output_tensor)
reduced_loss = average_losses_across_data_parallel_group([loss])
# cache reduced loss while accumulating gradients
self._reduced_loss_buffer.append(reduced_loss[0])
if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
# Reduced loss for logging.
average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr)
self.log('global_step', self.trainer.global_step, prog_bar=True)
self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step), prog_bar=True)
self._reduced_loss_buffer = []
return loss
def validation_step(self, batch, batch_idx):
tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
output_tensor = self(tokens, position_ids, attention_mask, labels)
loss = self.loss_func(loss_mask, output_tensor)
# TODO: add text generation
# take the first k tokens and then generate text - compare with ground truth (won't look similar in general)
"""
k = num_context
n = max_generate_length
context_tokens = tokens[0:k]
while k < n:
output_tensor = self(context_tokens)
next_token = sample(output_tensor)
context_tokens.append(next_token)
k += 1
"""
return loss
def validation_epoch_end(self, outputs):
averaged_loss = average_losses_across_data_parallel_group(outputs)
self.log('val_loss', averaged_loss[0], prog_bar=True)
self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step))
def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
def test_epoch_end(self, outputs):
averaged_loss = average_losses_across_data_parallel_group(outputs)
logging.info(f'test_loss: {averaged_loss[0]}')
def loss_func(self, loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
# TODO: add nemo version here
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll
return loss
def process_batch(self, batch):
# Items and their type.
keys = ['text']
datatype = torch.int64
data = batch
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
self.tokenizer.eos_id,
self.cfg.data.get('reset_position_ids', False),
self.cfg.data.get('reset_attention_mask', False),
self.cfg.data.get('eod_mask_loss', False),
)
return tokens, labels, loss_mask, attention_mask, position_ids
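The shift-by-one above is what turns a single token stream into (input, target) pairs; a toy illustration with arbitrary ids:

import torch

tokens_ = torch.tensor([[10, 11, 12, 13]])
tokens = tokens_[:, :-1]  # tensor([[10, 11, 12]])  fed to the model
labels = tokens_[:, 1:]   # tensor([[11, 12, 13]])  next-token targets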
def build_train_valid_test_datasets(self):
logging.info('Building GPT datasets.')
global_batch_size = self.trainer.world_size * self.cfg.micro_batch_size // self.cfg.tensor_model_parallel_size
eval_iters = (self.trainer.max_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
test_iters = self.trainer.limit_test_batches
train_valid_test_num_samples = [
self.trainer.max_steps * global_batch_size,
eval_iters * global_batch_size,
test_iters * global_batch_size,
]
self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
cfg=self.cfg,
trainer=self.trainer,
data_prefix=self.cfg.data.data_prefix,
data_impl=self.cfg.data.data_impl,
splits_string=self.cfg.data.splits_string,
train_valid_test_num_samples=train_valid_test_num_samples,
seq_length=self.cfg.data.seq_length,
seed=self.cfg.seed,
skip_warmup=self.cfg.data.get('skip_warmup', True),
)
if self._train_ds is not None:
logging.info(f'Length of train dataset: {len(self._train_ds)}')
if self._validation_ds is not None:
logging.info(f'Length of val dataset: {len(self._validation_ds)}')
if self._test_ds is not None:
logging.info(f'Length of test dataset: {len(self._test_ds)}')
logging.info(f'Finished building GPT datasets.')
return self._train_ds, self._validation_ds, self._test_ds
def build_pretraining_data_loader(self, dataset, consumed_samples):
"""Buld dataloader given an input dataset."""
if dataset is None:
return None
# Megatron sampler
if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
if self.cfg.data.dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=self.cfg.micro_batch_size,
data_parallel_rank=parallel_state.get_data_parallel_rank(),
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
elif self.cfg.data.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=self.cfg.micro_batch_size,
data_parallel_rank=parallel_state.get_data_parallel_rank(),
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
else:
raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
else:
raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"')
# Torch dataloader.
return torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, num_workers=self.cfg.data.num_workers, pin_memory=True,
)
def setup(self, stage=None):
if stage == 'predict':
return
# TODO: consider adding a ModelPT guard to check if model is being restored.
# allowing restored models to optionally setup datasets
self.build_train_valid_test_datasets()
self.setup_training_data(self.cfg.data)
self.setup_validation_data(self.cfg.data)
self.setup_test_data(self.cfg.data)
def setup_training_data(self, cfg):
if hasattr(self, '_train_ds'):
resume_checkpoint_path = self.trainer.checkpoint_connector.resume_checkpoint_path
if resume_checkpoint_path:
consumed_samples = int(
float(re.findall(r"consumed_samples\=([0-9]+\.[0-9]+)", resume_checkpoint_path)[0])
)
else:
consumed_samples = 0
logging.info(
f'Setting up train dataloader with len(self._train_ds): {len(self._train_ds)} and consumed samples: {consumed_samples}'
)
self._train_dl = self.build_pretraining_data_loader(self._train_ds, consumed_samples)
def setup_validation_data(self, cfg):
if hasattr(self, '_validation_ds'):
consumed_samples = 0
logging.info(
f'Setting up validation dataloader with len(self._validation_ds): {len(self._validation_ds)} and consumed samples: {consumed_samples}'
)
self._validation_dl = self.build_pretraining_data_loader(self._validation_ds, consumed_samples)
def setup_test_data(self, cfg):
if hasattr(self, '_test_ds'):
consumed_samples = 0
logging.info(
f'Setting up test dataloader with len(self._test_ds): {len(self._test_ds)} and consumed samples: {consumed_samples}'
)
self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)
def compute_consumed_samples(self, global_step):
app_state = AppState()
consumed_samples = (
global_step
* app_state.data_parallel_size
* self.cfg.micro_batch_size
* self.trainer.accumulate_grad_batches
)
return int(consumed_samples)
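A worked example of the bookkeeping above with made-up values: 100 optimizer steps on 8 data-parallel ranks, micro batch size 4 and 2 gradient-accumulation steps correspond to 6400 consumed samples.

global_step, data_parallel_size, micro_batch_size, accumulate_grad_batches = 100, 8, 4, 2
consumed_samples = global_step * data_parallel_size * micro_batch_size * accumulate_grad_batches
print(consumed_samples)  # 6400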
def on_before_optimizer_step(self, optimizer, optimizer_idx):
"""PTL hook that is called after unscaling gradients when using native amp.
We use gradient clipping implementation from megatron-lm.
"""
clip_val = self.trainer.gradient_clip_val
if clip_val is None:
return
clip_val = float(clip_val)
if clip_val <= 0:
return
if isinstance(self.trainer.accelerator_connector.precision_plugin, NLPNativeMixedPrecisionPlugin):
parameters = self.model.parameters()
clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
else:
return
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
request = batch
response = self.complete(request)
return response
def complete(self, request: Dict):
"""
Autoregressively invokes the language model in inference mode
Args:
request: Dictionary with the following fields
* prompt: a string of text that the model should complete.
* tokens_to_generate: how many tokens to generate while doing prompt completion.
* stop_after_sentence: (default True) whether to stop generation once sentence end is reached.
Returns:
response: A python dictionary with the following fields
* prompt: original text of the prompt
* tokenized_prompt: list of (str) tokens from prompt
* completion: a python dictionary with the following subfields:
* tokens: a list of triples (token, token_id, log_prob) comprising completion
* stop reason: either 'eos', 'sentence_end' or 'limit' indicating why generation stopped
* text: completion text (as a single string)
"""
response = {}
self.freeze()
logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)
response['tokenized_prompt'] = request['tokenized_prompt']
tokens = request['tokens']
# naive greedy slow loop
# TODO: add option for BeamSearchDecoder
response['prompt'] = request['prompt']
response['completion'] = {}
response['completion']['stop reason'] = 'limit'
for i in range(request.get("tokens_to_generate", 64)):
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=self.tokenizer.eos_id,
reset_position_ids=self.cfg.get('reset_position_ids', False),
reset_attention_mask=self.cfg.get('reset_attention_mask', False),
eod_mask_loss=self.cfg.get('eod_mask_loss', False),
)
# No labels during inference. Still need masks to not attend to the right
output_tensor = self(tokens, position_ids, attention_mask, labels=None)
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
tokens = torch.cat([torch.squeeze(tokens), token_ids[:, -1]])
response['completion']["tokens"] = list(
zip(self.tokenizer.ids_to_tokens(tokens), tokens.tolist(), log_probs.tolist()[0])
)
completion_text = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"])
if reached_eos: # Will it actually ever reach that?
response['completion']['stop reason'] = 'eos'
break
elif request.get("stop_after_sentence", True) and completion_text.endswith(('.', '!', '?')):
response['completion']['stop reason'] = 'sentence_end'
break
tokens = torch.unsqueeze(tokens, 0)
response['completion']["text"] = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"])
self.unfreeze()
return response
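A hedged sketch of a request that complete() could consume, given the fields read above; the prompt, token ids and the loaded `model` instance are hypothetical, and the tokenized fields would normally come from the model's tokenizer in the predict script.

import torch

request = {
    'prompt': 'Deep learning is',
    'tokenized_prompt': ['Deep', 'Ġlearning', 'Ġis'],
    'tokens': torch.tensor([[3191, 4673, 318]]),  # made-up BPE ids, shape [1, seq_len]
    'tokens_to_generate': 16,
    'stop_after_sentence': True,
}
response = model.complete(request)
print(response['completion']['text'], response['completion']['stop reason'])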
def list_available_models(self):
return None
def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
logging.info(
f'Padded vocab_size: {after}, original vocab_size: {orig_vocab_size}, dummy tokens: {after - orig_vocab_size}.'
)
return after
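A worked example of the padding rule above: with an original vocabulary of 50257 tokens, make_vocab_size_divisible_by=128 and tensor_model_parallel_size=2, the multiple is 256 and the vocabulary is padded to 50432, i.e. 175 dummy tokens.

orig, multiple = 50257, 128 * 2
padded = orig + (-orig) % multiple
print(padded, padded - orig)  # 50432 175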

View file

@@ -15,32 +15,29 @@
import hashlib
import json
import os
import shutil
import tarfile
import tempfile
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
import torch
from megatron import mpu
from megatron.checkpointing import get_checkpoint_version, set_checkpoint_version
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
from transformers import TRANSFORMERS_CACHE
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.modules import BertModule, MegatronBertEncoder
from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import VOCAB_FILE_NAME
from nemo.collections.nlp.modules.common.megatron.megatron_bert import (
get_megatron_checkpoint_version,
set_megatron_checkpoint_version,
)
from nemo.collections.nlp.modules.common.megatron.megatron_encoder import MegatronEncoderModule
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
from nemo.collections.nlp.parts.nlp_overrides import NLPCheckpointConnector
from nemo.collections.nlp.parts.nlp_overrides import NLPCheckpointConnector, NLPSaveRestoreConnector
from nemo.core.classes import ModelPT
from nemo.core.classes.exportable import Exportable
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.utils import AppState, logging
from nemo.utils.get_rank import is_global_rank_zero
__all__ = ['NLPModel']
@@ -55,6 +52,8 @@ class NLPModel(ModelPT, Exportable):
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg, trainer)
# handles model parallel save and restore logic
self._save_restore_connector = NLPSaveRestoreConnector()
self.set_world_size(trainer)
def register_artifact(
@@ -180,274 +179,31 @@ class NLPModel(ModelPT, Exportable):
f'Registering tokenizer vocab for {self.tokenizer} is not yet supported. Please override this method if needed.'
)
def _clip_gradients(self, optimizer, clip_val=None):
""" Override of PTL Gradient Clipping.
Enables model parallel gradient clipping from Megatron-LM.
Args:
optimizer ([type]): [description]
clip_val ([type], optional): [description]. Defaults to None.
"""
app_state = AppState()
# get clip_val from trainer if None is provided
if clip_val is None:
clip_val = float(self._trainer.gradient_clip_val)
if app_state.model_parallel_size is not None:
model = self._trainer.get_model()
parameters = model.parameters()
if mpu.model_parallel_is_initialized():
mpu.grads.clip_grad_norm(parameters=parameters, max_norm=clip_val)
else:
raise ValueError('Model parallel groups must be initialized to use model parallel gradient clipping.')
else:
return Accelerator._clip_gradients(self, optimizer, clip_val)
def setup(self, stage: str) -> None:
""" PTL hook that is called on all DDP processes. """
if stage == 'fit':
# adds self.bert_model config to .nemo file
if hasattr(self, 'bert_model') and self.bert_model is not None:
self.register_bert_model()
app_state = AppState()
if app_state.model_parallel_size is not None:
self._trainer.checkpoint_connector = NLPCheckpointConnector(self._trainer)
# # Configure checkpointing for model parallel
# if app_state.create_checkpoint_callback:
# # global rank 0 is configured by exp_manager
# if not is_global_rank_zero() and app_state.data_parallel_rank == 0:
# configure_checkpointing(
# self._trainer,
# app_state.log_dir,
# app_state.checkpoint_name,
# app_state.checkpoint_callback_params,
# )
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
""" LightningModule hook that's used to save things in addition to model weights. """
if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
checkpoint['checkpoint_version'] = get_checkpoint_version()
checkpoint['checkpoint_version'] = get_megatron_checkpoint_version()
return None
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
""" LightningModule hook that's used to restore things saved with on_save_checkpoint."""
if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
if get_checkpoint_version():
if get_megatron_checkpoint_version():
assert (
checkpoint['checkpoint_version'] == get_checkpoint_version()
), 'checkpoint version found on_load_checkpoint different than get_checkpoint_version'
checkpoint['checkpoint_version'] == get_megatron_checkpoint_version()
), 'checkpoint version found on_load_checkpoint different than get_megatron_checkpoint_version'
else:
set_checkpoint_version(checkpoint['checkpoint_version'])
set_megatron_checkpoint_version(checkpoint['checkpoint_version'])
logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}")
return None
# no rank check as model parallel models need to be saved on data parallel rank 0
def save_to(self, save_path: str):
"""
Saves model instance (weights and configuration) into a .nemo file.
You can use the "restore_from" method to fully restore the instance from a .nemo file.
A .nemo file is an archive (tar.gz) with the following contents:
model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg argument for the model's constructor.
model_weights.ckpt - model checkpoint
Args:
save_path: Path to the .nemo file where the model instance should be saved
"""
save_path = os.path.abspath(save_path)
app_state = AppState()
if app_state.model_parallel_size is not None:
self._default_save_to(save_path)
else:
# super.save_to only runs on global rank 0
return super().save_to(save_path)
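# Editor's note: a minimal sketch (not part of this PR) of inspecting the .nemo
# archive layout described in the docstring above. 'my_model.nemo' is a
# hypothetical path; a .nemo file is a tar.gz that contains model_config.yaml
# and model_weights.ckpt (plus mp_rank_XX subdirectories for model parallel models).
import tarfile

with tarfile.open('my_model.nemo', 'r:gz') as archive:
    for member in archive.getnames():
        print(member)  # e.g. ./model_config.yaml, ./model_weights.ckpt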
def _default_save_to(self, save_path: str):
app_state = AppState()
if app_state.model_parallel_size is not None:
# each model parallel rank creates a .nemo file
# after all .nemo files are created, each rank
# will add their checkpoint to global rank 0
base_dir = os.path.dirname(save_path) # use the directory to merge mp_rank .nemo files into one
# update save_path based on model parallel_rank
base_path = os.path.splitext(save_path)[0] # everything except the extension
mp_save_path = f'{base_path}_mp_rank_{app_state.model_parallel_rank:02d}.nemo'
if app_state.data_parallel_rank == 0:
super()._default_save_to(mp_save_path)
# barrier so that all processes have finished writing their weights before creating .nemo file
torch.distributed.barrier()
if is_global_rank_zero():
# extract all tar files
for mp_rank in range(app_state.model_parallel_size):
mp_tar_path = f'{base_path}_mp_rank_{mp_rank:02d}.nemo'
mp_tar = tarfile.open(mp_tar_path, 'r:gz')
mp_tar.extractall(path=os.path.join(base_dir, f'mp_rank_{mp_rank:02d}'))
mp_tar.close()
os.remove(mp_tar_path)
# move rank 0 .nemo extract to base_path
shutil.move(os.path.join(base_dir, 'mp_rank_00'), base_path)
# move mp_rank_00 checkpoint to mp_rank_00 directory inside base_path
os.mkdir(os.path.join(base_path, 'mp_rank_00'))
shutil.move(os.path.join(base_path, 'model_weights.ckpt'), os.path.join(base_path, 'mp_rank_00'))
# move other mp_rank checkpoints from base_dir to base_path
for mp_rank in range(1, app_state.model_parallel_size):
os.mkdir(os.path.join(base_path, f'mp_rank_{mp_rank:02d}'))
shutil.move(
os.path.join(base_dir, f'mp_rank_{mp_rank:02d}', 'model_weights.ckpt'),
os.path.join(base_path, f'mp_rank_{mp_rank:02d}'),
)
# clean up leftover directory
shutil.rmtree(os.path.join(base_dir, f'mp_rank_{mp_rank:02d}'))
# create tar file from base_path
self._make_nemo_file_from_folder(save_path, base_path)
# clean up base_path
shutil.rmtree(base_path)
elif is_global_rank_zero():
return super()._default_save_to(save_path)
else:
return
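# Editor's note: a sketch of the merged directory layout that _default_save_to
# above assembles before re-tarring into a single .nemo file, assuming a
# hypothetical model_parallel_size of 2. Paths are illustrative only.
#
#   base_path/
#       model_config.yaml
#       mp_rank_00/model_weights.ckpt
#       mp_rank_01/model_weights.ckpt
expected_members = ['model_config.yaml'] + [
    f'mp_rank_{rank:02d}/model_weights.ckpt' for rank in range(2)
]
print(expected_members)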
@classmethod
def restore_from(
cls,
restore_path: str,
override_config_path: Optional[Union[OmegaConf, str]] = None,
map_location: Optional[torch.device] = None,
strict: bool = True,
return_config: bool = False,
trainer: Trainer = None,
save_restore_connector: SaveRestoreConnector = None,
):
"""
Restores model instance (weights and configuration) from .nemo file.
Args:
restore_path: path to .nemo file from which model should be instantiated
override_config_path: path to a yaml config that will override the internal
config file or an OmegaConf / DictConfig object representing the model config.
map_location: Optional torch.device() to map the instantiated model to a device.
By default (None), it will select a GPU if available, falling back to CPU otherwise.
strict: Passed to load_state_dict. Set to True by default.
return_config: If set to true, will return just the underlying config of the restored
model as an OmegaConf DictConfig object without instantiating the model.
trainer: PyTorch Lightning trainer. Must be passed in order to restore model parallel .nemo files.
Example:
```
model = nemo.collections.nlp.models.TokenClassificationModel.restore_from('token_classification.nemo')
assert isinstance(model, nemo.collections.nlp.models.TokenClassificationModel)
```
Returns:
An instance of type cls or its underlying config (if return_config is set).
"""
if save_restore_connector is None:
save_restore_connector = SaveRestoreConnector()
if not os.path.exists(restore_path):
raise FileNotFoundError(f"Can't find {restore_path}")
app_state = AppState()
app_state.model_restore_path = os.path.abspath(os.path.expanduser(restore_path))
# detect if we have a model parallel .nemo file
with tempfile.TemporaryDirectory() as tmpdir:
cwd = os.getcwd()
os.chdir(tmpdir)
# detect if model parallel from tarfile
tar = tarfile.open(app_state.model_restore_path, "r:gz")
names = tar.getnames()
mp_ranks = []
for name in names:
if 'mp_rank' in name:
mp_ranks.append(name)
if mp_ranks:
app_state.model_parallel_size = len(mp_ranks) // 2 # directory and file are included in getnames()
# get checkpoint version
checkpoint_version_member = None
for member in tar.getmembers():
if 'megatron_checkpoint_version.json' in member.name:
checkpoint_version_member = member
tar.extract(checkpoint_version_member, tmpdir)
with open(checkpoint_version_member.name, 'r') as f:
checkpoint_version = json.load(f).get('checkpoint_version', None)
logging.info(
(
f'Detected model parallel .nemo file: {restore_path}. '
f'Assuming megatron model parallelism with '
f'model_parallel_size: {app_state.model_parallel_size} '
f'and checkpoint version: {checkpoint_version}'
)
)
tar.close()
os.chdir(cwd)
if app_state.model_parallel_size is not None:
if not isinstance(trainer, Trainer):
raise ValueError("trainer must be a PyTorch Lightning Trainer to restore model parallel .nemo files.")
if checkpoint_version is None:
raise ValueError(
"Restoring from megatron model parallel .nemo but could not find megatron checkpoint version."
)
else:
logging.info(f"Setting megatron checkpoint version: {checkpoint_version}")
set_checkpoint_version(checkpoint_version)
app_state.world_size = trainer.num_gpus * trainer.num_nodes
if trainer.local_rank is not None:
app_state.local_rank = trainer.local_rank
else:
raise ValueError("trainer.local_rank is None. local_rank needed to restore model parallel models.")
model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)
app_state.model_parallel_rank = model_parallel_rank
cls.update_save_restore_connector(save_restore_connector)
restored_model = cls._save_restore_connector.restore_from(
cls, app_state.model_restore_path, override_config_path, map_location, strict, return_config
)
restored_model.set_trainer(trainer)
return restored_model
else:
return super().restore_from(
app_state.model_restore_path,
override_config_path,
map_location,
strict,
return_config,
save_restore_connector=save_restore_connector,
)
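# Editor's note: a small self-contained sketch of the model parallel detection
# above. Each mp_rank appears twice in tar.getnames() (once as the directory
# entry and once as its model_weights.ckpt), hence the division by two. The
# names below are illustrative only.
names = [
    './mp_rank_00', './mp_rank_00/model_weights.ckpt',
    './mp_rank_01', './mp_rank_01/model_weights.ckpt',
]
mp_ranks = [name for name in names if 'mp_rank' in name]
assert len(mp_ranks) // 2 == 2  # model_parallel_size == 2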
@rank_zero_only
def register_megatron_checkpoint_version(self):
""" Adds checkpoint version to .nemo archive """
if self.has_megatron_encoder:
checkpoint_version = get_checkpoint_version()
checkpoint_version = get_megatron_checkpoint_version()
if checkpoint_version is None:
raise ValueError('Unable to get megatron checkpoint version.')
else:
@ -511,3 +267,72 @@ class NLPModel(ModelPT, Exportable):
if isinstance(self.encoder, MegatronEncoderModule):
logging.info(f"Restoring from pretrained model parallel checkpoint: {self.encoder.checkpoint_file}")
self.encoder._encoder.restore_weights(self.encoder.checkpoint_file)
@classmethod
def load_from_checkpoint(
cls,
checkpoint_path: str,
map_location: Any = None,
hparams_file: Optional[str] = None,
strict: bool = True,
**kwargs,
):
"""
Loads ModelPT from a checkpoint, handling NeMo-specific restore state.
For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
"""
checkpoint = None
try:
cls._set_model_restore_state(is_being_restored=True)
# TODO: replace with proper PTL API
with pl_legacy_patch():
if map_location is not None:
checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
if hparams_file is not None:
extension = hparams_file.split(".")[-1]
if extension.lower() == "csv":
hparams = load_hparams_from_tags_csv(hparams_file)
elif extension.lower() in ("yml", "yaml"):
hparams = load_hparams_from_yaml(hparams_file)
else:
raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
hparams["on_gpu"] = False
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
# for past checkpoint need to add the new key
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
# override the hparams with values that were passed in
# TODO: can we do this without overriding?
config_kwargs = kwargs.copy()
if 'trainer' in config_kwargs:
config_kwargs.pop('trainer')
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs)
model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
checkpoint = model
finally:
cls._set_model_restore_state(is_being_restored=False)
return checkpoint
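# Editor's note: a self-contained sketch (not part of this PR) of the
# hparams_file extension handling in load_from_checkpoint above; the file name
# 'hparams.yaml' is illustrative only.
def _hparams_loader_for(hparams_file: str) -> str:
    extension = hparams_file.split(".")[-1].lower()
    if extension == "csv":
        return "load_hparams_from_tags_csv"
    if extension in ("yml", "yaml"):
        return "load_hparams_from_yaml"
    raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

assert _hparams_loader_for("hparams.yaml") == "load_hparams_from_yaml"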
def save_to(self, save_path: str):
app_state = AppState()
# Add NeMo rank check as well
if app_state.model_parallel_size is not None:
if app_state.model_parallel_size > 1:
if not isinstance(self._save_restore_connector, NLPSaveRestoreConnector):
logging.warning(
f"Using {self._save_restore_connector.__class__} to save a model parallel model. Overriding with NLPSaveRestoreConnector. Make sure to subclass NLPSaveRestoreConnector."
)
self._save_restore_connector = NLPSaveRestoreConnector()
save_path = os.path.abspath(os.path.expanduser(save_path))
self._save_restore_connector.save_to(self, save_path)
else:
super(NLPModel, self).save_to(save_path=save_path)

View file

@ -39,16 +39,6 @@ class BertModule(NeuralModule, Exportable):
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}
@classmethod
def restore_from(cls, restore_path: str):
"""Restores module/model with weights"""
pass
@classmethod
def save_to(self, save_path: str):
"""Saves module/model with weights"""
pass
def restore_weights(self, restore_path: str):
"""Restores module/model's weights"""
logging.info(f"Restoring weights from {restore_path}")

View file

@ -89,13 +89,15 @@ def get_lm_model(
)
if "megatron" in pretrained_model_name:
model, checkpoint_file = get_megatron_lm_model(
config_dict=config_dict,
config_file=config_file,
pretrained_model_name=pretrained_model_name,
checkpoint_file=checkpoint_file,
vocab_file=vocab_file,
)
raise ValueError('megatron-lm BERT models have been deprecated in NeMo 1.5+. Please use NeMo 1.4 for support.')
# TODO: enable megatron bert in nemo
# model, checkpoint_file = get_megatron_lm_model(
# config_dict=config_dict,
# config_file=config_file,
# pretrained_model_name=pretrained_model_name,
# checkpoint_file=checkpoint_file,
# vocab_file=vocab_file,
# )
else:
model = get_huggingface_lm_model(
config_dict=config_dict, config_file=config_file, pretrained_model_name=pretrained_model_name,
@ -180,13 +182,17 @@ def get_transformer(
)
elif library == 'megatron':
model = get_megatron_transformer(
model_name=model_name,
pretrained=pretrained,
config_dict=config_dict,
encoder=encoder,
checkpoint_file=checkpoint_file,
raise ValueError(
'megatron-lm BERT support has been deprecated in NeMo 1.5+. Please use NeMo 1.4 for support.'
)
# TODO: enable megatron bert in nemo
# model = get_megatron_transformer(
# model_name=model_name,
# pretrained=pretrained,
# config_dict=config_dict,
# encoder=encoder,
# checkpoint_file=checkpoint_file,
# )
else:
raise ValueError("Libary must be 'nemo', 'huggingface' or 'megatron'")

View file

@ -0,0 +1,137 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradient clipping."""
import amp_C
import torch
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
from torch._six import inf
from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none:
grad = param.grad.detach()
# Make sure the grads are in fp32
assert isinstance(param.grad, torch.cuda.FloatTensor)
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf:
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
total_norm = total_norm_cuda[0].item()
else:
if norm_type == 2.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm
)
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
)
total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
return total_norm
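# Editor's note: a single-process, pure-PyTorch reference (editor's sketch, not
# part of this PR) for the L2 branch of clip_grad_norm_fp32 above, without apex
# or the model parallel all-reduce. Useful only for sanity-checking the math.
import torch

def clip_grad_norm_reference(parameters, max_norm, eps=1.0e-6):
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    total_norm = torch.norm(torch.stack([torch.norm(g, 2.0) for g in grads]), 2.0)
    clip_coeff = max_norm / (total_norm + eps)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 2.0)
norm = clip_grad_norm_reference([p], max_norm=1.0)  # norm == 4.0, grads scaled down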
def count_zeros_fp32(parameters):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
)
total_num_zeros = total_num_zeros.item()
return total_num_zeros

View file

@ -0,0 +1,30 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2

View file

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def bias_dropout_add_fused_train_(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True)
class BiasDropoutAddFusedTrain(AutocastModuleWrapper):
def __init__(self, fp16=False, bf16=False):
super(BiasDropoutAddFusedTrain, self).__init__(fp16, bf16)
self.func = bias_dropout_add_fused_train_
def forward(self, x, bias, residual, prob):
return self.autocast_forward(x, bias, residual, prob)
@torch.jit.script
def bias_dropout_add_fused_inference_(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False)
class BiasDropoutAddFusedInference(AutocastModuleWrapper):
def __init__(self, fp16=False, bf16=False):
super(BiasDropoutAddFusedInference, self).__init__(fp16, bf16)
self.func = bias_dropout_add_fused_inference_
def forward(self, x, bias, residual, prob):
return self.autocast_forward(x, bias, residual, prob)
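# Editor's note: a quick pure-PyTorch check (editor's sketch, not part of this
# PR) that the inference path of bias_dropout_add reduces to residual + x + bias,
# since dropout is the identity when training=False.
import torch

x, bias, residual = torch.randn(4, 8), torch.randn(8), torch.randn(4, 8)
out = torch.nn.functional.dropout(x + bias, p=0.1, training=False) + residual
assert torch.allclose(out, residual + x + bias)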

View file

@ -0,0 +1,68 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is the tanh approximation of gelu
# the exact gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
class FusedBiasGeLU(AutocastModuleWrapper):
def __init__(self, fp16=False, bf16=False):
super(FusedBiasGeLU, self).__init__(fp16, bf16)
self.func = GeLUFunction
def forward(self, input, bias):
return self.autocast_forward(input, bias)
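# Editor's note: a small numerical check (editor's sketch, not part of this PR)
# that the fused tanh approximation above tracks the exact erf-based gelu quoted
# in the comments.
import torch

x = torch.linspace(-4, 4, steps=101, dtype=torch.float64)
gelu_tanh = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
gelu_exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
assert torch.allclose(gelu_tanh, gelu_exact, atol=1e-3)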

View file

@ -0,0 +1,107 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import subprocess
from torch.utils import cpp_extension
# Setting this variable to a list can generate different compilation
# commands (with a different order of architectures) and lead to
# recompilation of the fused kernels. Set it to an empty string to
# avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below.
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def load():
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
_create_build_dir(buildpath)
# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math']
+ extra_cuda_flags
+ cc_flag,
verbose=True,
)
# ==============
# Fused softmax.
# ==============
extra_cuda_flags = [
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
]
# Upper triangular softmax.
sources = [
srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu',
]
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda", sources, extra_cuda_flags
)
# Masked softmax.
sources = [srcpath / 'scaled_masked_softmax.cpp', srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper("scaled_masked_softmax_cuda", sources, extra_cuda_flags)
# =================================
# Mixed precision fused layer norm.
# =================================
extra_cuda_flags = ['-maxrregcount=50']
sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags
)
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")

View file

@ -0,0 +1,31 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif

View file

@ -0,0 +1,201 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"
namespace {
void compute_n1_n2(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output = at::empty_like(
input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty(
{n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta
);
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine,
"LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
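# Editor's note: a pure-PyTorch reference (editor's sketch, not part of this PR)
# for the semantics of layer_norm_affine above: normalize over the trailing
# normalized_shape dimensions and also return the per-row mean and inverse
# standard deviation (invvar), which the backward pass reuses.
import torch

def layer_norm_affine_reference(x, normalized_shape, gamma, beta, eps=1e-5):
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    invvar = torch.rsqrt(var + eps)
    return (x - mean) * invvar * gamma + beta, mean, invvar

x, g, b = torch.randn(3, 5, 16), torch.ones(16), torch.zeros(16)
out, mean, invvar = layer_norm_affine_reference(x, (16,), g, b)
assert torch.allclose(out, torch.nn.functional.layer_norm(x, (16,), g, b), atol=1e-4)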

View file

@ -0,0 +1,829 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template<typename U> __device__
void cuChanOnlineSum(
const U muB,
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA*mu + nB*muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu= U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2;
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
}
}
}
template<> __device__
void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu= float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2;
int l = 8*thrx;
if ((((size_t)lvals)&3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
}
}
}
template<typename U> U rsqrt(U v) {
return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
return rsqrtf(v);
}
template<> double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
extern __shared__ float s_float[];
return s_float;
}
};
}
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
V* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
V* grad_gamma,
V* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
const V* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const V* k_dout = dout + i1*n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
buf[2*wrt_i+1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
sum_loss2 += buf[2*read_i+1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
buf[2*threadIdx.x+1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1*n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
const V* beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output,
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const V* gamma,
const V* beta,
double epsilon,
T* grad_input,
V* grad_gamma,
V* grad_beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
(threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
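# Editor's note: the CUDA kernels above accumulate mean/variance with Welford's
# online update (cuWelfordOnlineSum) and merge partial results with Chan's
# parallel formula (cuChanOnlineSum). This is an editor's pure-Python sketch of
# those two updates, not code added by this PR; sigma2 holds the sum of squared
# deviations until it is divided by the element count, as in the kernels.
def welford_online(curr, mu, sigma2, count):
    count += 1
    delta = curr - mu
    mu += delta / count
    sigma2 += delta * (curr - mu)
    return mu, sigma2, count

def chan_merge(mu_a, sigma2_a, n_a, mu_b, sigma2_b, n_b):
    n = n_a + n_b
    if n == 0:
        return 0.0, 0.0, 0
    delta = mu_b - mu_a
    mu = (n_a * mu_a + n_b * mu_b) / n
    sigma2 = sigma2_a + sigma2_b + delta * delta * n_a * n_b / n
    return mu, sigma2, n

mu, sigma2, count = 0.0, 0.0, 0
for value in [1.0, 2.0, 3.0, 4.0]:
    mu, sigma2, count = welford_online(value, mu, sigma2, count)
assert abs(mu - 2.5) < 1e-12 and abs(sigma2 / count - 1.25) < 1e-12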

View file

@ -0,0 +1,97 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
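# Editor's note: a pure-PyTorch reference (editor's sketch, not part of this PR)
# for the scaled, masked softmax bound above: scale the 4D attention scores,
# push masked positions (mask == 1 in these kernels) to -inf, and softmax over
# the last (key) dimension. Shapes below are illustrative only.
import torch

def scaled_masked_softmax_reference(scores, mask, scale):
    # scores, mask: [batches, attn_heads, query_seq_len, key_seq_len]
    masked = (scores * scale).masked_fill(mask.bool(), float('-inf'))
    return torch.softmax(masked, dim=-1)

scores = torch.randn(2, 4, 8, 8)
mask = torch.zeros(2, 4, 8, 8, dtype=torch.uint8)
probs = scaled_masked_softmax_reference(scores, mask, scale=0.125)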

View file

@ -0,0 +1,505 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
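/* Note: warp_reduce is a butterfly (XOR-shuffle) reduction; after log2(WARP_SIZE)
 * rounds every lane in the warp holds the fully reduced value for each of its
 * WARP_BATCH rows, so no shared memory or separate broadcast is needed. */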
/*
* Extended softmax (from native aten pytorch) with the following additional features
* 1) input scaling
* 2) Explicit masking
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int element_count,
int pad_batches)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count/batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
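The launch geometry above is plain integer arithmetic, so it can be sanity-checked outside CUDA. Below is a small Python mirror of the block/warp math shared by get_batch_per_block and the two dispatch functions, assuming the usual 32-thread warp (C10_WARP_SIZE == 32); it is a sketch for checking, not part of the extension:

def log2_ceil(value: int) -> int:
    log2_value = 0
    while (1 << log2_value) < value:
        log2_value += 1
    return log2_value

def batches_per_block(key_seq_len: int, warp_size_max: int = 32, threads_per_block: int = 128) -> int:
    next_power_of_two = 1 << log2_ceil(key_seq_len)
    warp_size = min(next_power_of_two, warp_size_max)
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    return (threads_per_block // warp_size) * batches_per_warp

# key_seq_len = 2048 -> warp_size 32, 1 batch/warp, 4 warps/block -> 4,
# so the forward dispatch then requires query_seq_len % 4 == 0.
assert batches_per_block(2048) == 4
assert batches_per_block(128) == 8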

@@ -0,0 +1,117 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
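A non-fused PyTorch reference of what fwd_cuda computes is handy when validating the kernel. This is a sketch based on the kernel source above: scale the scores, fill positions where the uint8 mask equals 1 with -10000, then take the softmax over the last dimension in fp32:

import torch

def scaled_masked_softmax_reference(inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # inputs: [b, np, sq, sk] (fp16/bf16); mask: [b or 1, 1, sq, sk] (uint8, 1 = masked)
    scores = inputs.float() * scale
    scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.softmax(scores, dim=-1).to(inputs.dtype)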

@@ -0,0 +1,72 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
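For the causal variant the mask is implicit, so a hedged, non-fused reference of the intended math looks like the sketch below: a 3D input [attn_batches, sq, sk] with sq == sk, where strictly upper-triangular entries are excluded from the softmax (the kernel writes exact zeros there):

import torch

def scaled_upper_triang_masked_softmax_reference(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    sq, sk = inputs.shape[-2:]
    causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool, device=inputs.device), diagonal=1)
    scores = (inputs.float() * scale).masked_fill(causal, float("-inf"))
    return torch.softmax(scores, dim=-1).to(inputs.dtype)  # masked positions come out as zeros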

@@ -0,0 +1,513 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with the following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
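Both backward kernels in this pair of headers compute the standard softmax Jacobian-vector product scaled by the same factor used in the forward: grad_input = scale * y * (dy - sum_k(dy_k * y_k)) along the softmax dimension. A compact PyTorch statement of that formula, for checking only:

import torch

def scaled_softmax_backward_reference(output_grads: torch.Tensor, softmax_results: torch.Tensor, scale: float) -> torch.Tensor:
    prod = output_grads.float() * softmax_results.float()                    # dy * y
    grad = prod - softmax_results.float() * prod.sum(dim=-1, keepdim=True)   # y * (dy - sum(dy * y))
    return (scale * grad).to(output_grads.dtype)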

@@ -0,0 +1,98 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
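Note that both bwd_cuda implementations overwrite output_grads and return it ("backward pass is completely in-place"), so a caller that still needs the incoming gradients should copy them first. A sketch of that calling pattern, using the extension name imported by the Python wrapper in this PR:

import scaled_upper_triang_masked_softmax_cuda  # assumes the extension has been built

def backward_keeping_output_grads(output_grads, softmax_results, scale):
    # clone() because the extension writes its result into the first argument
    return scaled_upper_triang_masked_softmax_cuda.backward(output_grads.clone(), softmax_results, scale)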

@@ -0,0 +1,91 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}

@@ -0,0 +1,211 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
class FusedScaledUpperTriangMaskedSoftmax(AutocastModuleWrapper):
def __init__(self, fp16=False, bf16=True):
super(FusedScaledUpperTriangMaskedSoftmax, self).__init__(fp16, bf16)
self.func = ScaledUpperTriangMaskedSoftmax
def forward(self, inputs, scale):
return self.autocast_forward(inputs, scale)
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class FusedScaledMaskedSoftmax(AutocastModuleWrapper):
def __init__(self, fp16=False, bf16=True):
super(FusedScaledMaskedSoftmax, self).__init__(fp16, bf16)
self.fused_func = ScaledMaskedSoftmax
def forward(self, inputs, mask, scale):
return self.autocast_forward(inputs, mask, scale)
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed in fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
self.fused_scaled_upper_triang_masked_softmax = FusedScaledUpperTriangMaskedSoftmax(
input_in_fp16, input_in_bf16
)
self.fused_scaled_masked_softmax = FusedScaledMaskedSoftmax(input_in_fp16, input_in_bf16)
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 or bf16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
# reshape input to a 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = self.fused_scaled_upper_triang_masked_softmax(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return self.fused_scaled_masked_softmax(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
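A minimal usage sketch of the FusedScaleMaskSoftmax module defined above. The mask function, the AttnMaskType.padding member, and the shapes are assumptions; the fused path is only taken when is_kernel_available() passes (fp16/bf16 input, 16 < sk <= 2048, and the divisibility checks above), otherwise the torch softmax fallback is used:

import torch
from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType

def attn_mask_func(scores, mask):                  # assumed mask function for the unfused fallback
    return scores.masked_fill(mask.bool(), -10000.0)

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attn_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

b, heads, sq, sk = 2, 8, 128, 128
scores = torch.randn(b, heads, sq, sk, dtype=torch.float16, device="cuda")
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device="cuda")   # 1 = masked
probs = softmax(scores, mask)                                        # [b, heads, sq, sk]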

@@ -0,0 +1,581 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType, LayerType
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
from nemo.collections.nlp.modules.common.megatron.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
# Gather if needed.
if parallel_output:
return logits_parallel
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(
hidden_size,
ffn_hidden_size,
num_layers,
max_position_embeddings,
num_tokentypes,
add_pooler,
vocab_size,
num_attention_heads,
encoder_attn_mask_type,
apply_query_key_layer_scaling=True,
kv_channels=None,
init_method=None,
scaled_init_method=None,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True,
post_process=True,
init_method_std=0.02,
use_cpu_initialization=False,
hidden_dropout=0.1,
fused_fp16=False,
fused_bf16=False,
fp32_residual_connection=False,
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
):
"""Build language model and return along with the key to save."""
assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."
if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
if init_method is None:
init_method = init_method_normal(init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
# Language model.
language_model = TransformerLanguageModel(
init_method=init_method,
output_layer_init_method=scaled_init_method,
encoder_attn_mask_type=encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
hidden_size=hidden_size,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
ffn_hidden_size=ffn_hidden_size,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
pre_process=pre_process,
post_process=post_process,
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=hidden_dropout,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
)
# key used for checkpoints.
language_model_key = 'language_model'
return language_model, language_model_key
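# Illustrative sketch (not part of the original change): a minimal GPT-style
# call covering only the required arguments. The sizes are examples and assume
# apex parallel state has already been initialized.
#
#   language_model, lm_key = get_language_model(
#       hidden_size=768,
#       ffn_hidden_size=3072,
#       num_layers=12,
#       max_position_embeddings=1024,
#       num_tokentypes=0,
#       add_pooler=False,
#       vocab_size=50304,
#       num_attention_heads=12,
#       encoder_attn_mask_type=AttnMaskType.causal,
#   )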
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
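# Illustrative sketch (not part of the original change): pooling the first
# token of a toy [b, s, h] tensor; the sizes below are chosen arbitrarily.
#
#   pooler = Pooler(hidden_size=16, init_method=init_method_normal(0.02))
#   hidden = torch.randn(2, 8, 16)      # [b, s, h]
#   pooled = pooler(hidden)             # [b, h] == [2, 16], tanh-activated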
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0,
use_cpu_initialization=False,
):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method, use_cpu_initialization=use_cpu_initialization,
)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
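# Illustrative sketch (not part of the original change): a toy forward pass on
# a single tensor-parallel rank (so the vocab-parallel embedding holds the full
# vocabulary). All sizes are examples only.
#
#   emb = Embedding(hidden_size=16, vocab_size=128, max_sequence_length=32,
#                   embedding_dropout_prob=0.1,
#                   init_method=init_method_normal(0.02))
#   input_ids = torch.randint(0, 128, (2, 8))                        # [b, s]
#   position_ids = torch.arange(8).unsqueeze(0).expand_as(input_ids) # [b, s]
#   out = emb(input_ids, position_ids)                               # [2, 8, 16]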
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
destination, prefix, keep_vars
)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print(
'***WARNING*** expected tokentype embeddings in the ' 'checkpoint but could not find it',
flush=True,
)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
vocab_size,
max_position_embeddings,
hidden_size,
ffn_hidden_size,
num_layers,
num_tokentypes,
num_attention_heads,
apply_query_key_layer_scaling=True,
kv_channels=None,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True,
use_cpu_initialization=False,
hidden_dropout=0.1,
fused_fp16=False,
fused_bf16=False,
fp32_residual_connection=False,
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
):
super(TransformerLanguageModel, self).__init__()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = hidden_size
self.num_layers = num_layers
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.hidden_dropout = hidden_dropout
self.output_layer_init_method = output_layer_init_method
if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
# Embeddings.
if self.pre_process:
self.embedding = Embedding(
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
max_sequence_length=self.max_position_embeddings,
init_method=self.init_method,
num_tokentypes=self.num_tokentypes,
use_cpu_initialization=use_cpu_initialization,
embedding_dropout_prob=self.hidden_dropout,
)
self._embedding_key = 'embedding'
# Transformer.
self.encoder = ParallelTransformer(
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
num_layers=self.num_layers,
hidden_size=self.hidden_size,
num_attention_heads=num_attention_heads,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
ffn_hidden_size=ffn_hidden_size,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
)
self._encoder_key = 'encoder'
# Decoder
if self.add_decoder:
assert (
parallel_state.get_pipeline_model_parallel_world_size() == 1
), 'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
num_layers=self.num_layers,
hidden_size=self.hidden_size,
num_attention_heads=num_attention_heads,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
ffn_hidden_size=ffn_hidden_size,
pre_process=self.pre_process,
post_process=self.post_process,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
)
self._decoder_key = 'decoder'
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
def forward(
self,
enc_input_ids,
enc_position_ids,
enc_attn_mask,
dec_input_ids=None,
dec_position_ids=None,
dec_attn_mask=None,
enc_dec_attn_mask=None,
tokentype_ids=None,
layer_past=None,
get_key_value=False,
pooling_sequence_index=0,
enc_hidden_states=None,
output_enc_hidden=False,
):
# Embeddings.
if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
encoder_input = embedding_output
else:
encoder_input = None
# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(
encoder_input, enc_attn_mask, layer_past=layer_past, get_key_value=get_key_value
)
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if self.post_process:
if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids)
# decoder
decoder_output = self.decoder(
dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load."""
state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_decoder:
state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self.pre_process:
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
if self.post_process:
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, 'could not find data for decoder in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)


@@ -17,11 +17,8 @@
import os
import torch
from megatron import get_args, initialize_megatron
from megatron.checkpointing import set_checkpoint_version
from megatron.model import get_language_model
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.mpu import (
from apex.transformer.enums import AttnMaskType
from apex.transformer.parallel_state import (
get_model_parallel_group,
model_parallel_is_initialized,
set_pipeline_model_parallel_rank,
@@ -30,6 +27,7 @@ from megatron.mpu import (
from omegaconf import DictConfig, OmegaConf
from nemo.collections.nlp.modules.common.bert_module import BertModule
from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model
from nemo.core.classes import typecheck
from nemo.utils import logging
from nemo.utils.app_state import AppState
@@ -37,13 +35,6 @@ from nemo.utils.app_state import AppState
__all__ = ['MegatronBertEncoder']
def complete_lazy_init(self):
# finish megatron-lm initialization
if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None:
self._lazy_init_fn()
self._lazy_init_fn = None
class MegatronBertEncoder(BertModule):
"""
MegatronBERT wraps around the Megatron Language model
@@ -59,96 +50,47 @@ class MegatronBertEncoder(BertModule):
super().__init__()
self._model_parallel_size = model_parallel_size
self._model_parallel_rank = model_parallel_rank
self._restore_path = None
self._app_state = None
self._model_name = model_name
if 'vocab_size' in config:
self._vocab_size = config.pop('vocab_size')
else:
self._vocab_size = None
self._hidden_size = config.get('hidden_size')
if not os.path.exists(vocab_file):
raise ValueError(f'Vocab file not found at {vocab_file}')
# convert config to dictionary
if isinstance(config, DictConfig):
config = OmegaConf.to_container(config)
config["vocab_file"] = vocab_file
config['tokenizer_type'] = 'BertWordPieceLowerCase'
config['lazy_mpu_init'] = True
config['onnx_safe'] = True
num_tokentypes = config.pop('num_tokentypes', 2)
# if 'model_parallel_size' in config:
if self._model_parallel_size is not None:
app_state = AppState()
self._app_state = app_state
# must be set for model parallel megatron-lm
os.environ["WORLD_SIZE"] = str(app_state.world_size)
os.environ["RANK"] = str(self._model_parallel_rank)
extra_args_provider = self._update_megatron_args(tensor_model_parallel_size=self._model_parallel_size)
else:
extra_args_provider = self._update_megatron_args()
# configure globals for megatron
set_pipeline_model_parallel_rank(0) # pipeline model parallelism not implemented in NeMo
set_pipeline_model_parallel_world_size(1) # pipeline model parallelism not implemented in NeMo
# Initialize part of Megatron global state that is needed for its constructor.
# We set 'lazy_mpu_init' flag on to make Megatron do only the initialization that does not depend
# on ddp be initialized yet (and we don't want Megatron to initialize DDP itself either)
# and to return a hook for us to call after PTL has torch.distributed initialized.
# (or if no PTL in case of inference - then we'll initialize torch.distributed)
# We call and clear this hook on first call to forward()
self._lazy_init_fn = initialize_megatron(
extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
raise ValueError(
f'megatron-lm bert has been deprecated in NeMo 1.5. Please use an earlier release of NeMo. Megatron bert support will be added back to NeMo in a future release.'
)
# read Megatron arguments back
args = get_args()
logging.info(f'Megatron-lm argparse args: {args}')
# self._model_parallel_size = model_parallel_size
# self._model_parallel_rank = model_parallel_rank
# self._restore_path = None
# self._app_state = None
# self._model_name = model_name
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func, num_tokentypes=num_tokentypes, add_pooler=False
)
# if 'vocab_size' in config:
# self._vocab_size = config.pop('vocab_size')
# else:
# self._vocab_size = None
self.config = OmegaConf.create(config)
# key used for checkpoints
self._hidden_size = self.language_model.hidden_size
# self._hidden_size = config.get('hidden_size')
def _update_megatron_args(
self,
micro_batch_size=1,
tensor_model_parallel_size=1,
scaled_masked_softmax_fusion=False,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
):
def extra_args_provider(parser):
parser.set_defaults(micro_batch_size=micro_batch_size)
parser.set_defaults(tensor_model_parallel_size=tensor_model_parallel_size)
parser.set_defaults(scaled_masked_softmax_fusion=scaled_masked_softmax_fusion)
parser.set_defaults(bias_gelu_fusion=bias_gelu_fusion)
parser.set_defaults(bias_dropout_fusion=bias_dropout_fusion)
# if not os.path.exists(vocab_file):
# raise ValueError(f'Vocab file not found at {vocab_file}')
return parser
# # convert config to dictionary
# if isinstance(config, DictConfig):
# config = OmegaConf.to_container(config)
# config["vocab_file"] = vocab_file
# config['tokenizer_type'] = 'BertWordPieceLowerCase'
# config['lazy_mpu_init'] = True
# config['onnx_safe'] = True
return extra_args_provider
# num_tokentypes = config.pop('num_tokentypes', 2)
def complete_lazy_init(self):
# finish megatron-lm initialization
if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None:
self._lazy_init_fn()
self._lazy_init_fn = None
# # configure globals for megatron
# set_pipeline_model_parallel_rank(0) # pipeline model parallelism not implemented in NeMo
# set_pipeline_model_parallel_world_size(1) # pipeline model parallelism not implemented in NeMo
# self.language_model, self._language_model_key = get_language_model(
# encoder_attn_mask_type=AttnMaskType.padding, num_tokentypes=num_tokentypes, add_pooler=False
# )
# self.config = OmegaConf.create(config)
# # key used for checkpoints
# self._hidden_size = self.language_model.hidden_size
@property
def hidden_size(self):
@@ -172,17 +114,14 @@ class MegatronBertEncoder(BertModule):
@typecheck()
def forward(self, input_ids, attention_mask, token_type_ids=None):
app_state = AppState()
if app_state.model_parallel_size is None:
self.complete_lazy_init()
extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
sequence_output = self.language_model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=extended_attention_mask,
enc_input_ids=input_ids,
enc_position_ids=position_ids,
enc_attn_mask=extended_attention_mask,
tokentype_ids=token_type_ids,
)
return sequence_output
@@ -196,13 +135,13 @@ class MegatronBertEncoder(BertModule):
state_dict = torch.load(filename, map_location='cpu')
if 'checkpoint_version' in state_dict:
if state_dict['checkpoint_version'] is not None:
set_checkpoint_version(state_dict['checkpoint_version'])
set_megatron_checkpoint_version(state_dict['checkpoint_version'])
logging.info(
f"Megatron-lm checkpoint version found. Setting checkpoint_version to {state_dict['checkpoint_version']}."
)
else:
logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
set_checkpoint_version(0)
set_megatron_checkpoint_version(0)
# to load from Megatron pretrained checkpoint
if 'model' in state_dict:
self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
@@ -233,3 +172,39 @@ class MegatronBertEncoder(BertModule):
logging.info(f'torch.distributed not initialized yet. Will not restore model parallel checkpoint')
else:
logging.error(f'restore_path: {restore_path} must be a file or directory.')
def get_megatron_checkpoint_version():
app_state = AppState()
return app_state._megatron_checkpoint_version
def set_megatron_checkpoint_version(version: int = None):
app_state = AppState()
app_state._megatron_checkpoint_version = version
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = extended_attention_mask < 0.5
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
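# Illustrative sketch (not part of the original change): building the inputs
# expected by MegatronBertEncoder.forward() from a padded [b, s] mask. Token
# values are arbitrary.
#
#   attention_mask = torch.tensor([[1, 1, 1, 0]])             # 1 = real token
#   extended = bert_extended_attention_mask(attention_mask)   # [b, 1, s, s], True where masked
#   token_ids = torch.tensor([[101, 7592, 2088, 0]])
#   position_ids = bert_position_ids(token_ids)               # [[0, 1, 2, 3]]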


@@ -0,0 +1,89 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import torch
from apex.transformer import tensor_parallel
from apex.transformer.parallel_state import (
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
set_pipeline_model_parallel_world_size,
set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size,
)
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.utils import AppState
def initialize_model_parallel_for_nemo(
world_size, global_rank, local_rank, tensor_model_parallel_size=1, seed=1234,
):
# updating NeMo globals
app_state = AppState()
app_state.global_rank = global_rank
app_state.world_size = world_size
app_state.model_parallel_size = tensor_model_parallel_size
app_state.model_parallel_rank = compute_model_parallel_rank(local_rank, tensor_model_parallel_size)
# update apex.mpu globals
set_tensor_model_parallel_world_size(tensor_model_parallel_size)
set_tensor_model_parallel_rank(app_state.model_parallel_rank)
# pipeline model parallelism not implemented in NeMo yet
set_pipeline_model_parallel_rank(0)
set_pipeline_model_parallel_world_size(1)
_set_random_seed(seed)
app_state._is_megatron_initialized = True
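# Illustrative sketch (not part of the original change): typical single-GPU,
# no-model-parallel initialization; the rank/size values are examples only.
#
#   initialize_model_parallel_for_nemo(
#       world_size=1, global_rank=0, local_rank=0,
#       tensor_model_parallel_size=1, seed=1234,
#   )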
def _set_random_seed(seed_):
"""Set random seed for reproducability."""
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + (100 * get_pipeline_model_parallel_rank())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


@@ -46,6 +46,14 @@ MEGATRON_CACHE = os.path.join(torch_home, "megatron")
CONFIGS = {"345m": {"hidden_size": 1024, "num_attention_heads": 16, "num_layers": 24, "max_position_embeddings": 512}}
MEGATRON_CONFIG_MAP = {
"megatron-gpt-345m": {
"config": CONFIGS["345m"],
"checkpoint": "models/nvidia/megatron_lm_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt",
"vocab": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
"merges_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
"do_lower_case": False,
"tokenizer_name": "gpt2",
},
"megatron-bert-345m-uncased": {
"config": CONFIGS["345m"],
"checkpoint": "https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt",
@@ -97,6 +105,7 @@ def get_megatron_lm_model(
config_file: Optional[str] = None,
checkpoint_file: Optional[str] = None,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
) -> Tuple[MegatronBertEncoder, str]:
"""
Returns MegatronBertEncoder and a default or user specified path to the checkpoint file
@@ -113,81 +122,87 @@ def get_megatron_lm_model(
model: MegatronBertEncoder
checkpoint_file: path to checkpoint file or directory
"""
config = None
# get default config and checkpoint
if config_file:
with open(config_file) as f:
config = json.load(f)
# replace dashes with underscores in config keys
fixed_config = {}
for key in config.keys():
fixed_key = key.replace("-", "_")
if fixed_key == 'max_seq_length':
fixed_key = 'max_position_embeddings'
fixed_config[fixed_key] = config[key]
# 'vocab_size' no longer used.
if 'vocab_size' in fixed_config:
fixed_config.pop('vocab_size')
config = fixed_config
elif config_dict:
config = config_dict
elif pretrained_model_name in get_megatron_lm_models_list():
config = get_megatron_config(pretrained_model_name)
else:
raise ValueError(f"{pretrained_model_name} is not supported")
if config is None:
raise ValueError(f"config_file or config_dict is required for {pretrained_model_name}")
if not checkpoint_file:
checkpoint_file = get_megatron_checkpoint(pretrained_model_name)
if not vocab_file:
vocab_file = get_megatron_vocab_file(pretrained_model_name)
app_state = AppState()
if app_state.model_parallel_size is not None and app_state.model_parallel_rank is not None:
# model parallel already known from .nemo restore
model_parallel_size = app_state.model_parallel_size
model_parallel_rank = app_state.model_parallel_rank
elif os.path.isdir(checkpoint_file):
# starting training from megatron-lm checkpoint
mp_ranks = glob.glob(os.path.join(checkpoint_file, 'mp_rank*'))
model_parallel_size = len(mp_ranks)
app_state.model_parallel_size = model_parallel_size
logging.info(
(
f'restore_path: {checkpoint_file} is a directory. '
f'Assuming megatron model parallelism with '
f'model_parallel_size: {model_parallel_size}'
)
)
# try to get local rank from global
local_rank = None
try:
local_rank = int(os.environ['LOCAL_RANK'])
except:
logging.info('Global variable LOCAL_RANK not yet specified')
if local_rank is not None:
app_state.local_rank = local_rank
else:
# if local is None then we are on the main process
local_rank = 0
model_parallel_rank = compute_model_parallel_rank(local_rank, model_parallel_size)
app_state.model_parallel_rank = model_parallel_rank
else:
model_parallel_size = None
model_parallel_rank = None
model = MegatronBertEncoder(
model_name=pretrained_model_name,
config=config,
vocab_file=vocab_file,
model_parallel_size=model_parallel_size,
model_parallel_rank=model_parallel_rank,
raise ValueError(
f'megatron-lm bert has been deprecated in NeMo 1.5. Please use an earlier release of NeMo. Megatron bert support will be added back to NeMo in a future release.'
)
# config = None
# # get default config and checkpoint
# if config_file:
# with open(config_file) as f:
# config = json.load(f)
# # replace dashes with underscores in config keys
# fixed_config = {}
# for key in config.keys():
# fixed_key = key.replace("-", "_")
# if fixed_key == 'max_seq_length':
# fixed_key = 'max_position_embeddings'
# fixed_config[fixed_key] = config[key]
# # 'vocab_size' no longer used.
# if 'vocab_size' in fixed_config:
# fixed_config.pop('vocab_size')
# config = fixed_config
# elif config_dict:
# config = config_dict
# elif pretrained_model_name in get_megatron_lm_models_list():
# config = get_megatron_config(pretrained_model_name)
# else:
# raise ValueError(f"{pretrained_model_name} is not supported")
return model, checkpoint_file
# if config is None:
# raise ValueError(f"config_file or config_dict is required for {pretrained_model_name}")
# if not checkpoint_file:
# checkpoint_file = get_megatron_checkpoint(pretrained_model_name)
# if not vocab_file:
# vocab_file = get_megatron_vocab_file(pretrained_model_name)
# if not merges_file:
# merges_file = get_megatron_merges_file(pretrained_model_name)
# app_state = AppState()
# if app_state.model_parallel_size is not None and app_state.model_parallel_rank is not None:
# # model parallel already known from .nemo restore
# model_parallel_size = app_state.model_parallel_size
# model_parallel_rank = app_state.model_parallel_rank
# elif os.path.isdir(checkpoint_file):
# # starting training from megatron-lm checkpoint
# mp_ranks = glob.glob(os.path.join(checkpoint_file, 'mp_rank*'))
# model_parallel_size = len(mp_ranks)
# app_state.model_parallel_size = model_parallel_size
# logging.info(
# (
# f'restore_path: {checkpoint_file} is a directory. '
# f'Assuming megatron model parallelism with '
# f'model_parallel_size: {model_parallel_size}'
# )
# )
# # try to get local rank from global
# local_rank = None
# try:
# local_rank = int(os.environ['LOCAL_RANK'])
# except:
# logging.info('Global variable LOCAL_RANK not yet specified')
# if local_rank is not None:
# app_state.local_rank = local_rank
# else:
# # if local is None then we are on the main process
# local_rank = 0
# model_parallel_rank = compute_model_parallel_rank(local_rank, model_parallel_size)
# app_state.model_parallel_rank = model_parallel_rank
# else:
# model_parallel_size = None
# model_parallel_rank = None
# model = MegatronBertEncoder(
# model_name=pretrained_model_name,
# config=config,
# vocab_file=vocab_file,
# model_parallel_size=model_parallel_size,
# model_parallel_rank=model_parallel_rank,
# )
# return model, checkpoint_file
def compute_model_parallel_rank(local_rank, model_parallel_size):
@@ -239,6 +254,26 @@ def get_megatron_vocab_file(pretrained_model_name: str) -> str:
return path
def get_megatron_merges_file(pretrained_model_name: str) -> str:
"""
Gets the merges file from cache or downloads it
Args:
pretrained_model_name: pretrained model name
Returns:
path: path to the merges file
"""
if 'gpt' not in pretrained_model_name.lower():
return None
_check_megatron_name(pretrained_model_name)
url = MEGATRON_CONFIG_MAP[pretrained_model_name]["merges_file"]
path = os.path.join(MEGATRON_CACHE, pretrained_model_name + "_merges")
path = _download(path, url)
return path
def get_megatron_checkpoint(pretrained_model_name: str) -> str:
"""
Gets checkpoint file from cache or downloads it


@@ -0,0 +1,92 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
import torch
from apex.transformer import parallel_state, tensor_parallel
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
return self.language_model.embedding.word_embeddings.weight
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if not self.share_word_embeddings:
raise Exception(
'word_embeddings_weight() called for last ' 'stage, but share_word_embeddings is false'
)
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be ' 'called for first and last stage only')
def initialize_word_embeddings(self, init_method, vocab_size, hidden_size, pipeline_model_parallel_size=1):
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but ' 'share_word_embeddings is false')
# TODO: pipeline model parallelism is not implemented in NeMo yet
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, do an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if parallel_state.is_pipeline_last_stage():
assert not parallel_state.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init_method
)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
torch.distributed.all_reduce(
self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()
)
else:
print(
"WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)


@@ -0,0 +1,850 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType, AttnType, LayerType
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
from nemo.collections.nlp.modules.common.megatron.fused_bias_dropout_add import (
BiasDropoutAddFusedInference,
BiasDropoutAddFusedTrain,
bias_dropout_add,
)
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import FusedBiasGeLU
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(
self,
init_method,
output_layer_init_method,
hidden_size,
ffn_hidden_size,
use_cpu_initialization=False,
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
fused_fp16=False,
fused_bf16=False,
):
super(ParallelMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
use_cpu_initialization=use_cpu_initialization,
)
self.bias_gelu_fusion = bias_gelu_fusion
self.activation_func = F.gelu
if openai_gelu:
self.activation_func = openai_gelu
elif onnx_safe:
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
ffn_hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=use_cpu_initialization,
)
self.bias_gelu_impl = FusedBiasGeLU(fused_fp16, fused_bf16)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = self.bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
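# Illustrative shape trace (not part of the original change), using the
# notation defined above in this file and a tensor-model-parallel size of p:
#   hidden_states [s, b, h]
#     -> dense_h_to_4h (column-parallel): [s, b, 4h/p], bias returned separately
#     -> bias + GeLU (fused or unfused)
#     -> dense_4h_to_h (row-parallel):    [s, b, h], bias again returned separately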
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
num_attention_heads,
hidden_size,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
fused_fp16=False,
fused_bf16=False,
apply_query_key_layer_scaling=True,
kv_channels=None,
use_cpu_initialization=False,
masked_softmax_fusion=True,
attention_dropout=0.1,
):
super(ParallelAttention, self).__init__()
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = False
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
projection_size = kv_channels * num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = tensor_parallel.divide(projection_size, world_size)
self.hidden_size_per_attention_head = tensor_parallel.divide(projection_size, num_attention_heads)
self.num_attention_heads_per_partition = tensor_parallel.divide(num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method,
use_cpu_initialization=use_cpu_initialization,
)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
hidden_size, projection_size, gather_output=False, init_method=init_method
)
self.key_value = tensor_parallel.ColumnParallelLinear(
hidden_size, 2 * projection_size, gather_output=False, init_method=init_method
)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
fused_fp16,
fused_bf16,
self.attn_mask_type,
masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=use_cpu_initialization,
)
def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
..., attention_scores.size(3) - 1, : attention_scores.size(3)
].unsqueeze(2)
else:
attention_mask = attention_mask[..., : attention_scores.size(3), : attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.random.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias
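# Illustrative shape trace for self-attention (not part of the original
# change), following the notation defined above in this file:
#   hidden_states [sq, b, h]
#     -> query_key_value (column-parallel): [sq, b, 3 * np * hn], split into q/k/v [sq, b, np, hn]
#     -> raw scores via baddbmm:            [b, np, sq, sk], scaled by 1 / norm_factor
#     -> fused scale-mask-softmax + dropout, then bmm with v: [b, np, sq, hn]
#     -> permute and reshape:               [sq, b, hp]
#     -> dense (row-parallel):              [sq, b, h] plus a separately returned bias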
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
class ParallelTransformerLayer_(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
hidden_size,
ffn_hidden_size,
num_attention_heads,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
fp32_residual_connection=False,
apply_residual_connection_post_layernorm=False,
fused_fp16=False,
fused_bf16=False,
apply_query_key_layer_scaling=True,
kv_channels=None,
layernorm_epsilon=1e-5,
hidden_dropout=0.1,
bias_dropout_fusion=True,
use_cpu_initialization=False,
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
masked_softmax_fusion=True,
attention_dropout=0.1,
):
super(ParallelTransformerLayer_, self).__init__()
assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."
if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm # if true apply residual connection post layer norm (like original bert)
self.bf16 = fused_bf16
self.fp32_residual_connection = fp32_residual_connection # if true move residual connections to fp32
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16)
# Self attention.
self.self_attention = ParallelAttention(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
num_attention_heads=num_attention_heads,
hidden_size=hidden_size,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
use_cpu_initialization=use_cpu_initialization,
masked_softmax_fusion=masked_softmax_fusion,
attention_dropout=attention_dropout,
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion # if true, enable bias dropout fusion
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
num_attention_heads=num_attention_heads,
hidden_size=hidden_size,
attention_type=AttnType.cross_attn,
attn_mask_type=AttnMaskType.padding,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
use_cpu_initialization=use_cpu_initialization,
masked_softmax_fusion=masked_softmax_fusion,
attention_dropout=attention_dropout,
)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16
)
# MLP
self.mlp = ParallelMLP(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
)
self.bias_dropout_add_fused_train = BiasDropoutAddFusedTrain(fused_fp16, fused_bf16)
self.bias_dropout_add_fused_inference = BiasDropoutAddFusedInference(fused_fp16, fused_bf16)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
layer_past=None,
get_key_value=False,
):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value
)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = self.bias_dropout_add_fused_train
else:
bias_dropout_add_func = self.bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = self.inter_attention(
layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)
if get_key_value:
output = [output, presents]
return output
class ParallelTransformerLayer(ParallelTransformerLayer_):
def __init__(self, **kwargs):
super(ParallelTransformerLayer, self).__init__(**kwargs)
assert not (kwargs['fused_fp16'] and kwargs['fused_bf16'])
if kwargs['fused_fp16']:
self.dtype = torch.float16
elif kwargs['fused_bf16']:
self.dtype = torch.bfloat16
else:
self.dtype = torch.float32
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
layer_past=None,
get_key_value=False,
):
if self.dtype == torch.float32:
return super().forward(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value
)
else:
with torch.cuda.amp.autocast(dtype=self.dtype):
return super().forward(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value
)
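# Illustrative note (not part of the original change): a layer built with
# fused_bf16=True (or fused_fp16=True) runs the parent forward under
# torch.cuda.amp.autocast with the matching dtype, while the default fp32 case
# skips autocast entirely. For example (other constructor arguments omitted):
#
#   layer = ParallelTransformerLayer(..., fused_bf16=True, fused_fp16=False)
#   output = layer(hidden_states, attention_mask)   # executed under bf16 autocast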
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(
self,
init_method,
output_layer_init_method,
num_layers,
hidden_size,
ffn_hidden_size,
num_attention_heads,
apply_query_key_layer_scaling=True,
kv_channels=None,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=True,
fused_fp16=False,
fused_bf16=False,
fp32_residual_connection=False,
activations_checkpoint_method=None,
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
hidden_dropout=0.1,
use_cpu_initialization=False,
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
):
super(ParallelTransformer, self).__init__()
assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."
if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
self.bf16 = fused_bf16
self.fp32_residual_connection = fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
# Store activation checkpointing flag.
self.activations_checkpoint_method = activations_checkpoint_method
self.activations_checkpoint_num_layers = activations_checkpoint_num_layers
# Number of layers.
assert (
num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
), 'num_layers must be divisible by pipeline_model_parallel_size'
self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
fused_fp16=fused_fp16,
fused_bf16=fused_bf16,
fp32_residual_connection=fp32_residual_connection,
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
)
# TODO: get virtual_pipeline_model_parallel_size from apex.mpu
# if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None:
# assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
# 'num_layers_per_stage must be divisible by ' 'virtual_pipeline_model_parallel_size'
# )
# # Number of layers in each model chunk is the number of layers in the stage,
# # divided by the number of model chunks in a stage.
# self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# # layers to stages like (each list is a model chunk):
# # Stage 0: [0] [2] [4] [6]
# # Stage 1: [1] [3] [5] [7]
# # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# # layers to stages like (each list is a model chunk):
# # Stage 0: [0, 1] [4, 5]
# # Stage 1: [2, 3] [6, 7]
# offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
# args.num_layers // args.virtual_pipeline_model_parallel_size
# ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
# else:
# # Each stage gets a contiguous set of layers.
# offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
# Make sure memory is freed.
tensor_parallel.reset_checkpointed_activations_memory_buffer()
if self.activations_checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of stored checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use the device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
)
else:
hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
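
The 'uniform' branch above walks the layer list in chunks of activations_checkpoint_num_layers and checkpoints each chunk's input, while 'block' checkpoints only the first N layers and runs the rest normally. A standalone sketch of the uniform scheme using plain torch.utils.checkpoint instead of tensor_parallel.checkpoint (toy layers, not the transformer layers defined here):

import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(32, 32) for _ in range(8)])
chunk = 2  # plays the role of activations_checkpoint_num_layers

def run_chunk(start, end):
    def custom_forward(x):
        for i in range(start, end):
            x = torch.relu(layers[i](x))
        return x
    return custom_forward

x = torch.randn(4, 32, requires_grad=True)
l = 0
while l < len(layers):
    # only each chunk's input is stored; intermediate activations are recomputed during backward
    x = checkpoint(run_chunk(l, l + chunk), x)
    l += chunk
x.sum().backward()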
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
get_key_value=False,
encoder_output=None,
enc_dec_attn_mask=None,
):
# Checks.
if layer_past is not None:
assert get_key_value, 'for not None values in layer_past, ' 'expected get_key_value to be set'
if get_key_value:
assert self.activations_checkpoint_method is None, (
'get_key_value does not work with ' 'activation checkpointing'
)
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
)
else:
if get_key_value:
presents = []
for index in range(self.num_layers):
layer = self._get_layer(index)
past = None
if layer_past is not None:
past = layer_past[index]
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
get_key_value=get_key_value,
)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
if get_key_value:
output = [output, presents]
return output


@ -0,0 +1,166 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for models."""
import math
import torch
from apex.transformer import parallel_state
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually the Python equivalent of torch.nn.functional.gelu(), with type hints for the ONNX exporter
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
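
Both activations target the same function: gelu_impl is the tanh approximation, whose constant 0.7978845608028654 is sqrt(2/pi), and erf_gelu is the exact erf form written so the ONNX exporter can trace it. A quick numeric sanity check (illustrative snippet, not part of the PR):

import math
import torch

x = torch.linspace(-4, 4, steps=101)
tanh_gelu = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * x * (1.0 + 0.044715 * x * x)))
exact_gelu = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
print(math.sqrt(2.0 / math.pi))              # 0.7978845608028654
print((tanh_gelu - exact_gelu).abs().max())  # small approximation error (order 1e-3 or below)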
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group())
averaged_losses = averaged_losses / torch.distributed.get_world_size(
group=parallel_state.get_data_parallel_group()
)
return averaged_losses
def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length
)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
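
On a toy batch the helper above yields a causal attention mask (True marks positions to mask out after the < 0.5 inversion), a loss mask that zeroes EOD positions, and position ids that restart after every EOD token. A small usage sketch, assuming the function is importable from the Megatron utils module added in this PR:

import torch

eod = 0
data = torch.tensor([[5, 7, eod, 9, 4, eod, 3, 1]])  # [micro_batch_size=1, seq_length=8]
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data, eod_token=eod, reset_position_ids=True, reset_attention_mask=True, eod_mask_loss=True
)
print(attention_mask.shape)  # torch.Size([1, 1, 8, 8])
print(loss_mask[0])          # tensor([1., 1., 0., 1., 1., 0., 1., 1.])
print(position_ids[0])       # tensor([0, 1, 2, 0, 1, 2, 0, 1])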
class AutocastModuleWrapper(torch.nn.Module):
def __init__(self, fp16=False, bf16=False):
super(AutocastModuleWrapper, self).__init__()
# TODO: @titu1994 Maybe useful to make this adaptive - prefer bf16 over fp16 when both flags are set and both dtypes are supported?
assert not (fp16 and bf16)
if fp16 or bf16:
self.precision = torch.float16 if fp16 else torch.bfloat16
self.use_autocast = True
else:
self.use_autocast = False
# class or function to autocast should be defined
self.func = None
def autocast_forward(self, *args):
assert self.func is not None
if self.use_autocast:
if isinstance(self.func, torch.autograd.function.FunctionMeta):
fwd = torch.cuda.amp.custom_fwd(cast_inputs=self.precision)(self.func.apply)
else:
fwd = torch.cuda.amp.custom_fwd(cast_inputs=self.precision)(self.func)
return fwd(*args)
else:
if isinstance(self.func, torch.autograd.function.FunctionMeta):
return self.func.apply(*args)
else:
return self.func(*args)
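
The wrapper dispatches on what self.func is: an autograd.Function class is routed through .apply (via torch.cuda.amp.custom_fwd when fp16/bf16 is requested), anything else is called directly. A minimal, hypothetical subclass exercising the no-autocast path:

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

class SquareWrapper(AutocastModuleWrapper):
    def __init__(self, fp16=False, bf16=False):
        super().__init__(fp16, bf16)
        self.func = Square  # a FunctionMeta instance, so autocast_forward uses .apply

    def forward(self, x):
        return self.autocast_forward(x)

m = SquareWrapper()  # both precision flags off: plain Square.apply, no autocast
y = m(torch.randn(4, requires_grad=True))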


@ -31,6 +31,13 @@ from nemo.utils import logging
__all__ = ['get_tokenizer', 'get_tokenizer_list']
megatron_tokenizer_model_map = {
'BertWordPieceLowerCase': 'megatron-bert-345m-uncased',
'BertWordPieceCase': 'megatron-bert-345m-cased',
'GPT2BPETokenizer': 'megatron-gpt-345m',
}
def get_tokenizer_list() -> List[str]:
"""
Returns all supported tokenizer names
@ -57,6 +64,7 @@ def get_tokenizer(
tokenizer_name: str,
tokenizer_model: Optional[str] = None,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
special_tokens: Optional[Dict[str, str]] = None,
use_fast: Optional[bool] = False,
bpe_dropout: Optional[float] = 0.0,
@ -84,6 +92,9 @@ def get_tokenizer(
vocab_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_vocab_file(
tokenizer_name
)
merges_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_merges_file(
tokenizer_name
)
tokenizer_name = get_megatron_tokenizer(tokenizer_name)
if tokenizer_name == 'sentencepiece':
@ -101,7 +112,11 @@ def get_tokenizer(
f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_name}, vocab_file: {vocab_file}, special_tokens_dict: {special_tokens_dict}, and use_fast: {use_fast}"
)
return AutoTokenizer(
pretrained_model_name=tokenizer_name, vocab_file=vocab_file, **special_tokens_dict, use_fast=use_fast
pretrained_model_name=tokenizer_name,
vocab_file=vocab_file,
merges_file=merges_file,
**special_tokens_dict,
use_fast=use_fast,
)
@ -110,6 +125,7 @@ def get_nmt_tokenizer(
model_name: Optional[str] = None,
tokenizer_model: Optional[str] = None,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
special_tokens: Optional[Dict[str, str]] = None,
use_fast: Optional[bool] = False,
bpe_dropout: Optional[float] = 0.0,
@ -138,10 +154,14 @@ def get_nmt_tokenizer(
elif library == 'huggingface':
logging.info(f'Getting HuggingFace AutoTokenizer with pretrained_model_name: {model_name}')
return AutoTokenizer(
pretrained_model_name=model_name, vocab_file=vocab_file, **special_tokens_dict, use_fast=use_fast
pretrained_model_name=model_name,
vocab_file=vocab_file,
merges_file=merges_file,
**special_tokens_dict,
use_fast=use_fast,
)
elif library == 'sentencepiece':
logging.info(f'Getting SentencePiece with model: {model_name}')
logging.info(f'Getting SentencePiece with model: {tokenizer_model}')
return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer(
model_path=tokenizer_model, special_tokens=special_tokens_dict
)
@ -149,10 +169,12 @@ def get_nmt_tokenizer(
logging.info(f'Using byte-level tokenization')
return ByteLevelTokenizer()
elif library == 'megatron':
if model_name in megatron_tokenizer_model_map:
model_name = megatron_tokenizer_model_map[model_name]
logging.info(
f'Getting Megatron tokenizer for pretrained model name: {model_name} and custom vocab file: {vocab_file}'
)
return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file)
return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file)
else:
raise NotImplementedError(
'Currently we only support "yttm", "huggingface", "sentencepiece", "megatron", and "byte-level" tokenizer libraries.'
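
With the mapping above, a Megatron GPT-2 BPE tokenizer can now be requested by its Megatron name, and the vocab/merges files are fetched automatically when not supplied. An illustrative call (module path and automatic download behavior assumed from this diff, not guaranteed):

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# 'GPT2BPETokenizer' maps to 'megatron-gpt-345m' via megatron_tokenizer_model_map;
# leaving vocab_file/merges_file unset lets get_tokenizer download the Megatron files.
tokenizer = get_nmt_tokenizer(library='megatron', model_name='GPT2BPETokenizer')
ids = tokenizer.text_to_ids("Megatron GPT in NeMo")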


@ -13,23 +13,34 @@
# limitations under the License.
import os
import shutil
import tempfile
from typing import Any, Dict, List, Optional, Union
import pytorch_lightning as pl
import torch
from megatron import mpu
from megatron.checkpointing import get_checkpoint_version, set_checkpoint_version
from megatron.initialize import _set_random_seed
from pytorch_lightning.core.lightning import LightningModule
from apex.transformer import parallel_state
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.enums import GradClipAlgorithmType
from torch.nn.modules.module import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer
from nemo.collections.nlp.modules.common.megatron.megatron_bert import MegatronBertEncoder
from nemo.collections.nlp.modules.common.megatron.megatron_encoder import MegatronEncoderModule
from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32
from nemo.collections.nlp.modules.common.megatron.megatron_bert import (
get_megatron_checkpoint_version,
set_megatron_checkpoint_version,
)
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.utils import AppState, logging
@ -45,21 +56,22 @@ class NLPDDPPlugin(DDPPlugin):
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
checkpoint_io: Optional[CheckpointIO] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm, **kwargs)
super().__init__(parallel_devices, num_nodes, cluster_environment, checkpoint_io, sync_batchnorm, **kwargs)
def init_ddp_connection(self, global_rank: int = None, world_size: int = None) -> None:
def setup_distributed(self, global_rank: int = None, world_size: int = None) -> None:
# call PTL init ddp
super().init_ddp_connection()
super().setup_distributed()
# init model parallel if needed
app_state = AppState()
if app_state.model_parallel_size is not None:
if self.lightning_module.has_megatron_encoder and not self.lightning_module.is_model_parallel_initialized:
self.init_model_parallel(app_state.global_rank, app_state.world_size)
self.init_model_parallel(app_state.global_rank, app_state.world_size)
# if self.lightning_module.has_megatron_encoder and not self.lightning_module.is_model_parallel_initialized:
# self.init_model_parallel(app_state.global_rank, app_state.world_size)
def start_training(self, trainer: 'Trainer') -> None:
""" PTL Hook that is called after DPP is initialized. """
@ -73,7 +85,7 @@ class NLPDDPPlugin(DDPPlugin):
if not hasattr(p, 'model_parallel'):
p.model_parallel = False
if get_checkpoint_version() is not None:
if get_megatron_checkpoint_version() is not None:
# megatron checkpoint already restored
pass
elif trainer.checkpoint_connector.resume_checkpoint_path is not None:
@ -92,14 +104,14 @@ class NLPDDPPlugin(DDPPlugin):
'checkpoint_version', None
)
if checkpoint_version is not None:
set_checkpoint_version(checkpoint_version)
set_megatron_checkpoint_version(checkpoint_version)
else:
logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
set_checkpoint_version(0)
set_megatron_checkpoint_version(0)
else:
self.lightning_module.restore_megatron_encoder_weights()
else:
if get_checkpoint_version() is not None:
if get_megatron_checkpoint_version() is not None:
# megatron checkpoint already restored
pass
else:
@ -117,7 +129,7 @@ class NLPDDPPlugin(DDPPlugin):
if self.has_megatron_encoder:
# check megatron checkpoint version
checkpoint_version = get_checkpoint_version()
checkpoint_version = get_megatron_checkpoint_version()
if checkpoint_version is None:
raise ValueError("Unable to find megatron checkpoint version.")
@ -125,7 +137,7 @@ class NLPDDPPlugin(DDPPlugin):
def configure_ddp(self):
""" Override LightningModule ddp if using model parallel.
Sets find_unused_parameters to True.
Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
"""
app_state = AppState()
@ -142,7 +154,7 @@ class NLPDDPPlugin(DDPPlugin):
device_ids=device_ids,
output_device=device_ids[0],
process_group=app_state.data_parallel_group,
find_unused_parameters=True,
find_unused_parameters=False,
**self._ddp_kwargs,
)
@ -163,17 +175,35 @@ class NLPDDPPlugin(DDPPlugin):
# after initializing DDP with PTL.
if app_state.model_parallel_size is not None:
if torch.distributed.is_initialized():
mpu.initialize_model_parallel(app_state.model_parallel_size)
app_state.model_parallel_group = mpu.get_model_parallel_group()
app_state.data_parallel_group = mpu.get_data_parallel_group()
app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
app_state.data_parallel_rank = mpu.get_data_parallel_rank()
app_state.data_parallel_size = mpu.get_data_parallel_world_size()
parallel_state.initialize_model_parallel(app_state.model_parallel_size)
app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
app_state.data_parallel_group = parallel_state.get_data_parallel_group()
app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
logging.info(f'mp_rank: {app_state.model_parallel_rank}')
logging.info(f'dp_rank: {app_state.data_parallel_rank}')
seed = os.environ.get("PL_GLOBAL_SEED", 1234)
# random seed must be set for megatron model parallel init
_set_random_seed(seed)
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
app_state = AppState()
# dump states as a checkpoint dictionary object
# TrainingTypePlugin.on_save() just seems to return the same thing.
# checkpoint = self.on_save(checkpoint)
if self.is_global_zero or app_state.data_parallel_rank == 0:
try:
# write the checkpoint dictionary on the file
atomic_save(checkpoint, filepath)
except AttributeError as err:
key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
checkpoint.pop(key, None)
rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
atomic_save(checkpoint, filepath)
@property
def distributed_sampler_kwargs(self):
@ -195,18 +225,17 @@ class NLPCheckpointConnector(CheckpointConnector):
""" Override PTL CheckpointConnector to support model parallel checkpoints from Megatron-LM.
"""
def __init__(self, trainer):
super().__init__(trainer)
def __init__(self, trainer, resume_from_checkpoint):
super().__init__(trainer, resume_from_checkpoint)
def save_checkpoint(self, filepath, weights_only: bool):
def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
"""Slightly modified version of PyTorch Lightning's save_checkpoint.
Accounts for model parallel training.
Save model/training states as a checkpoint file through state-dump and file-write.
Args:
filepath ([str]): [description]
weights_only (bool): [description]
Returns:
[type]: [description]
filepath: write-target file's path
weights_only: saving model weights only
"""
app_state = AppState()
if app_state.model_parallel_size is not None:
@ -214,22 +243,137 @@ class NLPCheckpointConnector(CheckpointConnector):
dirname = os.path.dirname(filepath)
basename = os.path.basename(filepath)
filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
# dump states as a checkpoint dictionary object
checkpoint = self.dump_checkpoint(weights_only)
_checkpoint = self.dump_checkpoint(weights_only)
# each model parallel rank needs to save a copy of its model
if app_state.data_parallel_rank == 0:
# write the checkpoint dictionary on the file
if self.trainer.accelerator_backend:
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
return None
self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
else:
super().save_checkpoint(filepath, weights_only)
class NLPSaveRestoreConnector(SaveRestoreConnector):
def __init__(self) -> None:
super().__init__()
def save_to(self, model, save_path: str):
app_state = AppState()
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
dir_name = os.path.dirname(save_path)
# first we save the weights for each model parallel rank
if app_state.data_parallel_rank == 0:
mp_model_weights = os.path.join(
dir_name, f'mp_rank_{app_state.model_parallel_rank:02d}_' + self.model_weights_ckpt
)
self._save_state_dict_to_disk(model.state_dict(), mp_model_weights)
torch.distributed.barrier()
# create nemo file from folder with all mp_ranks checkpoints
if app_state.model_parallel_rank == 0 and app_state.data_parallel_rank == 0:
with tempfile.TemporaryDirectory() as tmpdir:
# move weights to the tmpdir
for mp_rank in range(app_state.model_parallel_size):
os.makedirs(os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}'))
mp_model_weights = os.path.join(dir_name, f'mp_rank_{mp_rank:02d}_' + self.model_weights_ckpt)
shutil.move(
mp_model_weights, os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}', self.model_weights_ckpt)
)
# create config and artifacts in tmpdir
config_yaml = os.path.join(tmpdir, self.model_config_yaml)
model.to_config_file(path2yaml_file=config_yaml)
if hasattr(model, 'artifacts') and model.artifacts is not None:
self._handle_artifacts(model, nemo_file_folder=tmpdir)
self._update_artifact_paths(model, path2yaml_file=config_yaml)
# create tar file
self._make_nemo_file_from_folder(save_path, tmpdir)
else:
return super().save_to(model, save_path)
class NLPNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None:
super().__init__(precision=16)
self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval)
def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType,
model: Optional[Module],
) -> None:
"""Override PTL gradient clipping.
Do nothing because we've already clipped gradients in `on_before_optimizer_step` hook.
"""
pass
class NLPNativeBfloat16PrecisionPlugin(NativeMixedPrecisionPlugin):
def __init__(self) -> None:
super().__init__(precision='bf16')
def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType,
model: Optional[Module],
) -> None:
"""Override PTL gradient clipping.
Model parallel models require gradient clipping from megatron-lm.
"""
if clip_val is None:
return
clip_val = float(clip_val)
if clip_val <= 0:
return
app_state = AppState()
if app_state.model_parallel_size is not None:
parameters = model.parameters()
clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
else:
return super().clip_gradients(
optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model
)
class NLPPrecisionPlugin(PrecisionPlugin):
def __init__(self) -> None:
super().__init__()
def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType,
model: Optional[Module],
) -> None:
"""Override PTL gradient clipping.
Model parallel models require gradient clipping from megatron-lm.
"""
if clip_val is None:
return
clip_val = float(clip_val)
if clip_val <= 0:
return
app_state = AppState()
if app_state.model_parallel_size is not None:
parameters = model.parameters()
clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
else:
return super().clip_gradients(
optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model
)


@ -150,7 +150,7 @@ class FastSpeech2Model(SpectrogramGenerator):
self.log(name="train_loss", value=total_loss)
return {"loss": total_loss, "outputs": [spec, mel]}
def training_epoch_end(self, training_step_outputs):
def training_epoch_end(self, outputs):
if self.log_train_images and self.logger is not None and self.logger.experiment is not None:
tb_logger = self.logger.experiment
if isinstance(self.logger, LoggerCollection):
@ -158,7 +158,7 @@ class FastSpeech2Model(SpectrogramGenerator):
if isinstance(logger, TensorBoardLogger):
tb_logger = logger.experiment
break
spec_target, spec_predict = training_step_outputs[0]["outputs"]
spec_target, spec_predict = outputs[0]["outputs"]
tb_logger.add_image(
"train_mel_target",
plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
@ -171,6 +171,8 @@ class FastSpeech2Model(SpectrogramGenerator):
)
self.log_train_images = False
return super().training_epoch_end(outputs)
def validation_step(self, batch, batch_idx):
f, fl, t, tl, _, _, _ = batch
spec, spec_len = self.audio_to_melspec_preprocessor(f, fl)


@ -126,7 +126,7 @@ class HifiGanModel(Vocoder, Exportable):
def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
return self(spec=spec).squeeze(1)
def training_step(self, batch, batch_idx, optimizer_idx):
def training_step(self, batch, batch_idx):
# if in finetune mode the mels are pre-computed using a
# spectrogram generator
if self.input_as_mel:


@ -15,6 +15,7 @@
"""Interfaces common to all Neural Modules and Models."""
import hashlib
import inspect
import traceback
from abc import ABC, abstractmethod
from contextlib import contextmanager
@ -429,7 +430,7 @@ class Typing(ABC):
class Serialization(ABC):
@classmethod
def from_config_dict(cls, config: 'DictConfig'):
def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = None):
"""Instantiates object using DictConfig-based configuration"""
# Resolve the config dict
if _HAS_HYDRA:
@ -470,7 +471,12 @@ class Serialization(ABC):
imported_cls = cls
try:
instance = imported_cls(cfg=config)
accepts_trainer = Serialization._inspect_signature_for_trainer(imported_cls)
if accepts_trainer:
instance = imported_cls(cfg=config, trainer=trainer)
else:
instance = imported_cls(cfg=config)
except Exception as e:
imported_cls_tb = traceback.format_exc()
instance_init_error = str(e)
@ -484,8 +490,14 @@ class Serialization(ABC):
f"Falling back to `cls`.\n"
f"{imported_cls_tb}"
)
instance = cls(cfg=config, trainer=trainer)
try:
instance = cls(cfg=config)
accepts_trainer = Serialization._inspect_signature_for_trainer(cls)
if accepts_trainer:
instance = cls(cfg=config, trainer=trainer)
else:
instance = cls(cfg=config)
except Exception as e:
if imported_cls_tb is not None:
logging.error(f"Instance failed restore_from due to: {instance_init_error}")
@ -515,6 +527,17 @@ class Serialization(ABC):
'to_config_dict() can currently only return object._cfg but current object does not have it.'
)
@classmethod
def _inspect_signature_for_trainer(cls, check_cls):
if hasattr(check_cls, '__init__'):
signature = inspect.signature(check_cls.__init__)
if 'trainer' in signature.parameters:
return True
else:
return False
else:
return False
class FileIO(ABC):
def save_to(self, save_path: str):
@ -529,6 +552,8 @@ class FileIO(ABC):
map_location: Optional['torch.device'] = None,
strict: bool = True,
return_config: bool = False,
trainer: Optional['Trainer'] = None,
save_restore_connector: SaveRestoreConnector = None,
):
"""Restores module/model with weights"""
raise NotImplementedError()
@ -644,6 +669,7 @@ class Model(Typing, Serialization, FileIO):
map_location: Optional['torch.device'] = None,
strict: bool = True,
return_config: bool = False,
trainer: Optional['Trainer'] = None,
save_restore_connector: SaveRestoreConnector = None,
):
"""
@ -708,6 +734,7 @@ class Model(Typing, Serialization, FileIO):
map_location=map_location,
strict=strict,
return_config=return_config,
trainer=trainer,
save_restore_connector=save_restore_connector,
)
return instance


@ -201,7 +201,6 @@ class ModelPT(LightningModule, Model):
return self._save_restore_connector.register_artifact(self, config_path, src, verify_src_exists)
@rank_zero_only
def save_to(self, save_path: str):
"""
Saves model instance (weights and configuration) into .nemo file
@ -215,12 +214,24 @@ class ModelPT(LightningModule, Model):
save_path: Path to .nemo file where model instance should be saved
"""
app_state = AppState()
# Add NeMo rank check as well
if not is_global_rank_zero():
return
else:
if app_state.model_parallel_size is not None:
if app_state.model_parallel_size > 1:
if isinstance(self._save_restore_connector, SaveRestoreConnector):
raise ValueError(
'Default NeMo SaveRestoreConnector will not work in model parallel mode. You should use a connector which supports model parallel mode, such as NLPSaveRestoreConnector in NLP. You can also use a custom one.'
)
save_path = os.path.abspath(os.path.expanduser(save_path))
# connector checks for ranks properly, no need to check here
self._save_restore_connector.save_to(self, save_path)
else:
if not is_global_rank_zero():
return
else:
save_path = os.path.abspath(os.path.expanduser(save_path))
self._save_restore_connector.save_to(self, save_path)
@classmethod
def restore_from(
@ -231,6 +242,7 @@ class ModelPT(LightningModule, Model):
strict: bool = True,
return_config: bool = False,
save_restore_connector: SaveRestoreConnector = None,
trainer: Optional[Trainer] = None,
):
"""
Restores model instance (weights and configuration) from .nemo file.
@ -244,6 +256,8 @@ class ModelPT(LightningModule, Model):
strict: Passed to load_state_dict. By default True.
return_config: If set to true, will return just the underlying config of the restored
model as an OmegaConf DictConfig object without instantiating the model.
trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the
instantiated model's constructor.
save_restore_connector (SaveRestoreConnector): Can be overridden to add custom save and restore logic.
Example:
@ -268,7 +282,7 @@ class ModelPT(LightningModule, Model):
cls.update_save_restore_connector(save_restore_connector)
instance = cls._save_restore_connector.restore_from(
cls, restore_path, override_config_path, map_location, strict, return_config
cls, restore_path, override_config_path, map_location, strict, return_config, trainer
)
if isinstance(instance, ModelPT):
instance._save_restore_connector = save_restore_connector


@ -48,6 +48,7 @@ class TrainerConfig:
tpu_cores: Optional[Any] = None
log_gpu_memory: Optional[str] = None
progress_bar_refresh_rate: int = 1
enable_progress_bar: bool = True
overfit_batches: Any = 0.0
track_grad_norm: Any = -1
check_val_every_n_epoch: int = 1
@ -65,11 +66,10 @@ class TrainerConfig:
log_every_n_steps: int = 50
accelerator: Optional[str] = None
sync_batchnorm: bool = False
precision: int = 32
precision: Any = 32
weights_summary: Optional[str] = "full" # ModelSummary.MODE_DEFAULT
weights_save_path: Optional[str] = None
num_sanity_val_steps: int = 2
truncated_bptt_steps: Optional[int] = None
resume_from_checkpoint: Optional[str] = None
profiler: Optional[Any] = None
benchmark: bool = False
@ -77,11 +77,12 @@ class TrainerConfig:
reload_dataloaders_every_epoch: bool = False
auto_lr_find: Any = False
replace_sampler_ddp: bool = True
detect_anomaly: bool = False
terminate_on_nan: bool = False
auto_scale_batch_size: Any = False
prepare_data_per_node: bool = True
amp_backend: str = 'native'
amp_level: str = 'O2' # backward compatible, todo: remove in v1.0.0
amp_level: Optional[str] = None
plugins: Optional[Any] = None # Optional[Union[str, list]]
move_metrics_to_cpu: bool = False
multiple_trainloader_mode: str = 'max_size_cycle'


@ -51,6 +51,18 @@ class WarmupHoldSchedulerParams(WarmupSchedulerParams):
min_lr: float = 0.0
@dataclass
class WarmupAnnealingHoldSchedulerParams(WarmupSchedulerParams):
"""
Base configuration for all schedulers.
It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
"""
constant_steps: Optional[float] = None
constant_ratio: Optional[float] = None
min_lr: float = 0.0
@dataclass
class SquareAnnealingParams(WarmupSchedulerParams):
"""
@ -72,7 +84,7 @@ class SquareRootAnnealingParams(WarmupSchedulerParams):
@dataclass
class CosineAnnealingParams(WarmupSchedulerParams):
class CosineAnnealingParams(WarmupAnnealingHoldSchedulerParams):
"""
Cosine Annealing parameter config
It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
@ -98,7 +110,7 @@ class WarmupAnnealingParams(WarmupSchedulerParams):
It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
"""
warmup_ratio: 0.0
warmup_ratio: float = 0.0
@dataclass
@ -240,6 +252,7 @@ AVAILABLE_SCHEDULER_PARAMS = {
'SchedulerParams': SchedulerParams,
'WarmupPolicyParams': WarmupSchedulerParams,
'WarmupHoldPolicyParams': WarmupHoldSchedulerParams,
'WarmupAnnealingHoldSchedulerParams': WarmupAnnealingHoldSchedulerParams,
'SquareAnnealingParams': SquareAnnealingParams,
'SquareRootAnnealingParams': SquareRootAnnealingParams,
'InverseSquareRootAnnealingParams': InverseSquareRootAnnealingParams,


@ -23,9 +23,11 @@ from typing import Optional, Union
import torch
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.get_rank import is_global_rank_zero
class SaveRestoreConnector:
@ -47,16 +49,19 @@ class SaveRestoreConnector:
save_path: Path to .nemo file where model instance should be saved
"""
with tempfile.TemporaryDirectory() as tmpdir:
config_yaml = os.path.join(tmpdir, self.model_config_yaml)
model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
model.to_config_file(path2yaml_file=config_yaml)
if hasattr(model, 'artifacts') and model.artifacts is not None:
self._handle_artifacts(model, nemo_file_folder=tmpdir)
# We should not update self._cfg here - the model can still be in use
self._update_artifact_paths(model, path2yaml_file=config_yaml)
self._save_state_dict_to_disk(model.state_dict(), model_weights)
self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
if is_global_rank_zero():
with tempfile.TemporaryDirectory() as tmpdir:
config_yaml = os.path.join(tmpdir, self.model_config_yaml)
model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
model.to_config_file(path2yaml_file=config_yaml)
if hasattr(model, 'artifacts') and model.artifacts is not None:
self._handle_artifacts(model, nemo_file_folder=tmpdir)
# We should not update self._cfg here - the model can still be in use
self._update_artifact_paths(model, path2yaml_file=config_yaml)
self._save_state_dict_to_disk(model.state_dict(), model_weights)
self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
else:
return
def restore_from(
self,
@ -66,6 +71,7 @@ class SaveRestoreConnector:
map_location: Optional[torch.device] = None,
strict: bool = True,
return_config: bool = False,
trainer: Trainer = None,
):
"""
Restores model instance (weights and configuration) from a .nemo file
@ -125,7 +131,7 @@ class SaveRestoreConnector:
return instance
else:
app_state = AppState()
if app_state.model_parallel_rank is not None:
if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1:
model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
else:
model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
@ -133,7 +139,7 @@ class SaveRestoreConnector:
os.chdir(cwd)
# get the class
calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir)
instance = calling_cls.from_config_dict(config=conf)
instance = calling_cls.from_config_dict(config=conf, trainer=trainer)
instance = instance.to(map_location)
# add load_state_dict override
instance.load_state_dict(
@ -369,6 +375,8 @@ class SaveRestoreConnector:
@staticmethod
def _make_nemo_file_from_folder(filename, source_dir):
dirname = os.path.dirname(filename)
os.makedirs(dirname, exist_ok=True)
with tarfile.open(filename, "w:gz") as tar:
tar.add(source_dir, arcname=".")


@ -160,6 +160,84 @@ class WarmupHoldPolicy(WarmupPolicy):
return self._get_lr(step)
class WarmupAnnealHoldPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
min_lr: Minimum learning rate to hold at after the decay.
constant_steps: Number of steps to keep lr constant at.
constant_ratio: Ratio of steps to keep lr constant.
"""
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (
warmup_steps is not None and warmup_ratio is not None
), "Either use a particular number of steps or a ratio"
assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total number of steps"
# It is necessary to assign all attributes *before* __init__,
# as the class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if constant_steps is not None:
self.constant_steps = constant_steps
self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
assert self.decay_steps > 0
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
assert self.decay_steps > 0
else:
self.constant_steps = 0
self.decay_steps = max_steps - self.warmup_steps
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
)
step = self.last_epoch
if step <= self.warmup_steps:
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
mult = ((max_steps - step) / max_steps) ** 0.5
out_lr = initial_lr * mult
@ -180,6 +258,30 @@ def _cosine_annealing(initial_lr, step, max_steps, min_lr):
return out_lr
def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr):
assert max_lr > min_lr
# Use linear warmup for the initial part.
if warmup_steps > 0 and step <= warmup_steps:
return max_lr * float(step) / float(warmup_steps)
# For any steps larger than `decay_steps`, use `min_lr`.
if step > warmup_steps + decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = step - warmup_steps
decay_steps_ = decay_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = max_lr - min_lr
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
return min_lr + coeff * delta_lr
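
Evaluating the helper above at a few steps makes the three phases explicit: linear warmup to max_lr, cosine decay toward min_lr, then the min_lr floor. A standalone restatement with made-up hyperparameters:

import math

def lr_at(step, max_lr=1e-4, min_lr=1e-5, warmup_steps=100, decay_steps=900):
    # mirrors _linear_warmup_with_cosine_annealing above
    if warmup_steps > 0 and step <= warmup_steps:
        return max_lr * step / warmup_steps
    if step > warmup_steps + decay_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / decay_steps
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

for s in (0, 50, 100, 550, 1000, 2000):
    print(s, lr_at(s))
# 0 -> 0.0 and 50 -> 5e-05 (warmup), 100 -> 1e-04 (peak),
# 550 -> 5.5e-05 (mid-decay), 1000 -> 1e-05 (decay done), 2000 -> 1e-05 (floor)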
def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
if cycle:
multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
@ -221,7 +323,7 @@ class SquareRootAnnealing(WarmupPolicy):
return new_lrs
class CosineAnnealing(WarmupPolicy):
class CosineAnnealing(WarmupAnnealHoldPolicy):
def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
@ -232,15 +334,27 @@ class CosineAnnealing(WarmupPolicy):
f"{self} received an initial learning rate that was lower than the minimum learning rate."
)
new_lrs = [
_cosine_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
)
for initial_lr in self.base_lrs
]
if self.constant_steps is None:
new_lrs = [
_cosine_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
)
for initial_lr in self.base_lrs
]
else:
new_lrs = [
_linear_warmup_with_cosine_annealing(
max_lr=self.base_lrs[0],
warmup_steps=self.warmup_steps,
step=step,
decay_steps=self.decay_steps,
min_lr=self.min_lr,
)
for _ in self.base_lrs
]
return new_lrs
@ -587,7 +701,17 @@ def prepare_lr_scheduler(
# Compute effective num max_steps
num_samples = len(train_dataloader.dataset)
batch_size = train_dataloader.batch_size
# TODO: not sure if this will be the correct LR schedule for Megatron
# we may need to override ModelPT setup_optimization
if train_dataloader.batch_size is not None:
batch_size = train_dataloader.batch_size
elif train_dataloader.batch_sampler is not None:
if train_dataloader.batch_sampler.micro_batch_size is not None:
batch_size = train_dataloader.batch_sampler.micro_batch_size
else:
raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}')
else:
raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}')
drop_last = train_dataloader.drop_last
max_steps = compute_max_steps(


@ -17,6 +17,7 @@ from functools import partial
from typing import Any, Dict, List, Optional, Union
import hydra
import torch
import torch.optim as optim
from omegaconf import DictConfig, OmegaConf
from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
@ -41,10 +42,12 @@ AVAILABLE_OPTIMIZERS = {
try:
from apex.optimizers import FusedLAMB
from apex.optimizers import FusedAdam
AVAILABLE_OPTIMIZERS['lamb'] = FusedLAMB
AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam
except ModuleNotFoundError:
logging.warning("Apex was not found. Using the lamb optimizer will error out.")
logging.warning("Apex was not found. Using the lamb or fused_adam optimizer will error out.")
__all__ = ['get_optimizer', 'register_optimizer', 'parse_optimizer_args']
@ -165,6 +168,9 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
raise ValueError(
f"Cannot resolve optimizer '{name}'. Available optimizers are : " f"{AVAILABLE_OPTIMIZERS.keys()}"
)
if name == 'fused_adam':
if not torch.cuda.is_available():
raise ValueError(f'CUDA must be available to use fused_adam.')
optimizer = AVAILABLE_OPTIMIZERS[name]
optimizer = partial(optimizer, **kwargs)


@ -14,9 +14,9 @@
MAJOR = 1
MINOR = 4
MINOR = 5
PATCH = 0
PRE_RELEASE = ''
PRE_RELEASE = 'b1'
# Use the following formatting: (major, minor, patch, pre-release)
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)


@ -44,8 +44,10 @@ class AppState(metaclass=Singleton):
self._world_size = None
self._model_parallel_size = None
self._model_parallel_group = None
self._is_megatron_initialized = False
self._data_parallel_size = None
self._data_parallel_group = None
self._megatron_checkpoint_version = None
self._random_seed = None


@ -31,11 +31,10 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.types import _METRIC
from nemo.constants import NEMO_ENV_VARNAME_VERSION
from nemo.utils import app_state, logging, timers
from nemo.utils import logging, timers
from nemo.utils.app_state import AppState
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import is_global_rank_zero
@ -72,7 +71,6 @@ class CallbackParams:
save_top_k: Optional[int] = 3
save_weights_only: Optional[bool] = False
mode: Optional[str] = "min"
period: Optional[int] = None
every_n_val_epochs: Optional[int] = 1
prefix: Optional[str] = None # If None, exp_manager will attempt to handle the filepath
postfix: str = ".nemo"
@ -380,6 +378,7 @@ def check_resume(
NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
ValueError: If resume is True and more than one checkpoint was found.
"""
if not log_dir:
raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")
@ -707,58 +706,55 @@ class NeMoModelCheckpoint(ModelCheckpoint):
for _ in range(models_to_delete):
model = best_k_models.pop(-1)
self.best_k_models.pop(model)
self._del_model(model)
self._del_model_without_trainer(model)
logging.debug(f"Removed checkpoint: {model}")
self.kth_best_model_path = best_k_models[-1]
self.best_model_path = best_k_models[0]
self.best_model_score = self.best_k_models[self.best_model_path]
@rank_zero_only
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
if not self.always_save_nemo:
return output
# Load the best model and then re-save it
app_state = AppState()
# since we are creating tarfile artifacts we need to update .nemo path
app_state.model_restore_path = os.path.abspath(
os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
)
if self.save_best_model:
if not os.path.exists(self.best_model_path):
return output
if self.best_model_path == self.previous_best_path:
return output
self.previous_model_path = self.best_model_path
old_state_dict = deepcopy(pl_module.state_dict())
checkpoint = torch.load(self.best_model_path, map_location='cpu')
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
# get a new instance of the model
pl_module.load_state_dict(checkpoint, strict=True)
pl_module.save_to(save_path=app_state.model_restore_path)
pl_module.load_state_dict(old_state_dict, strict=True)
else:
pl_module.save_to(save_path=app_state.model_restore_path)
return output
# Load the best model and then re-save it
app_state = AppState()
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
raise ValueError(f'always_save_nemo is not implemented for model parallel models.')
# since we are creating tarfile artifacts we need to update .nemo path
app_state.model_restore_path = os.path.abspath(
os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
)
if self.save_best_model:
if not os.path.exists(self.best_model_path):
return output
if self.best_model_path == self.previous_best_path:
return output
self.previous_model_path = self.best_model_path
old_state_dict = deepcopy(pl_module.state_dict())
checkpoint = torch.load(self.best_model_path, map_location='cpu')
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
# get a new instance of the model
pl_module.load_state_dict(checkpoint, strict=True)
pl_module.save_to(save_path=app_state.model_restore_path)
pl_module.load_state_dict(old_state_dict, strict=True)
else:
pl_module.save_to(save_path=app_state.model_restore_path)
return output
@rank_zero_only
def on_train_end(self, trainer, pl_module):
if trainer.fast_dev_run:
return None
app_state = AppState()
if app_state.model_parallel_size is not None:
return None
# TODO: make this work for model parallel, need to call on data parallel rank 0 and update best_model_path
# Load the best model and then re-save it
if self.save_best_model:
trainer.checkpoint_connector.restore(self.best_model_path)
pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))
def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
@ -768,18 +764,36 @@ class NeMoModelCheckpoint(ModelCheckpoint):
app_state = AppState()
if app_state.model_parallel_size is not None:
# filepath needs to be updated to include mp_rank
# TODO: figure out a good way to update these filepaths
dirname = os.path.dirname(filepath)
basename = os.path.basename(filepath)
filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
# each model parallel rank needs to remove its model
if app_state.data_parallel_rank == 0:
super()._del_model(trainer, filepath)
logging.info(f"Removed model parallel checkpoint: {filepath}")
if self._fs.exists(filepath):
self._fs.rm(filepath)
logging.info(f"Removed model parallel checkpoint: {filepath}")
else:
return super()._del_model(trainer, filepath)
def _del_model_without_trainer(self, filepath: str) -> None:
app_state = AppState()
if app_state.model_parallel_size is not None:
# filepath needs to be updated to include mp_rank
dirname = os.path.dirname(filepath)
basename = os.path.basename(filepath)
filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
# each model parallel rank needs to remove its model
if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0):
try:
self._fs.rm(filepath)
logging.info(f"Removed checkpoint: {filepath}")
except:
logging.info(f"Tried to remove checkpoint: {filepath} but failed.")
def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, _METRIC]) -> None:
""" Overrides PTL method to account for model parallel checkpoints.
Checks for data parallel rank 0 rather than global rank 0.
@ -792,11 +806,18 @@ class NeMoModelCheckpoint(ModelCheckpoint):
filepath = self._format_checkpoint_name(self.CHECKPOINT_NAME_LAST, monitor_candidates)
filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")
self._save_model(trainer, filepath)
trainer.save_checkpoint(filepath)
# TODO: figure out where self.last_model_path is being set
if self.last_model_path is not None:
if 'mp_rank' in self.last_model_path:
last_model_path = Path(self.last_model_path)
last_model_path = last_model_path.parent.parent.joinpath(last_model_path.name)
self.last_model_path = str(last_model_path)
# for model parallel we need to delete models for each model parallel rank
if self.last_model_path and self.last_model_path != filepath and app_state.data_parallel_rank == 0:
self._del_model(self.last_model_path)
self._del_model(trainer, self.last_model_path)
self.last_model_path = filepath
@ -813,7 +834,8 @@ class NeMoModelCheckpoint(ModelCheckpoint):
return
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
self._save_model(trainer, filepath)
trainer.save_checkpoint(filepath)
if (
self.save_top_k is None
@ -821,7 +843,7 @@ class NeMoModelCheckpoint(ModelCheckpoint):
and self.best_model_path != filepath
and app_state.data_parallel_rank == 0
):
self._del_model(self.best_model_path)
self._del_model(trainer, self.best_model_path)
self.best_model_path = filepath
else:
@ -887,14 +909,6 @@ def configure_checkpointing(
f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
)
if params.period is not None:
logging.warning(
"The use of `period` in the checkpoint callback is deprecrated, please use `every_n_val_epochs` instead. "
"Overwriting `every_n_val_epochs` with `period`."
)
params.every_n_val_epochs = params.period
# NOTE: checkpoint_callback should be called last
checkpoint_callback = NeMoModelCheckpoint(n_resume=resume, **params)
checkpoint_callback.last_model_path = trainer.checkpoint_connector.resume_checkpoint_path or ""
trainer.callbacks.append(checkpoint_callback)
