NeMo/nemo/collections/nlp/modules/common/megatron/language_model.py
Eric Harper 32fa5cfaf3
[BigNLP] Merge Megatron GPT to main (#2975)
* fix gpu init after removing debug print in mpu

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_adam

Signed-off-by: ericharper <complex451@gmail.com>

* check ds is not none before logging len

Signed-off-by: ericharper <complex451@gmail.com>

* set fp16 arg to true and fix enum conflict

Signed-off-by: ericharper <complex451@gmail.com>

* make fp16 arg configurable

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* Linear warmup with cosine annealing and constant holding (#2846)

* Testing cosine schedule

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* More fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update config for constant steps in schedule

Signed-off-by: ericharper <complex451@gmail.com>

* temporarily import enum from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip for fp32

Signed-off-by: ericharper <complex451@gmail.com>

* update check for _del_model_without_trainer

Signed-off-by: ericharper <complex451@gmail.com>

* updating restore for model parallel

Signed-off-by: ericharper <complex451@gmail.com>

* add predict script

Signed-off-by: ericharper <complex451@gmail.com>

* update test iters

Signed-off-by: ericharper <complex451@gmail.com>

* add barrier

Signed-off-by: ericharper <complex451@gmail.com>

* return if clip_val is 0 or None

Signed-off-by: ericharper <complex451@gmail.com>

* when using amp clip grads after they are unscaled

Signed-off-by: ericharper <complex451@gmail.com>

* make native amp scaler hyperparams configurable

Signed-off-by: ericharper <complex451@gmail.com>

* (1) nvfuser, (2) amp-casting decoration (#2894)

* (1) nvfuser, (2) amp-casting decoration

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* support bf16

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* update package info

Signed-off-by: ericharper <complex451@gmail.com>

* add set device to constructor

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Remove megatron-lm dependency. (#2910)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>

* NVfuser (#2943)

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* Megatron gpt bfloat support (#2926)

* Save/restore fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Another merge

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Bf16 args in init

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Set precision

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove debug stuff

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add bf16 casting decorator

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Bfloat layernorm propagation

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* More arg removal

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove BERTDataset

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update to latest apex and patch transformer autocast

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: ericharper <complex451@gmail.com>

* don't set jit for bf16

Signed-off-by: ericharper <complex451@gmail.com>

* replace apex.mpu

Signed-off-by: ericharper <complex451@gmail.com>

* fix grad clip

Signed-off-by: ericharper <complex451@gmail.com>

* NVFuser fixes (#2951)

* Fuser fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove dummy handler

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove PTL plugin based logic for fusion

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* remove duplicated file

Signed-off-by: ericharper <complex451@gmail.com>

* typo (#2960)

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Script to convert GPT checkpoint to .nemo (#2958)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* add script to convert .ckpt to .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update

Signed-off-by: ericharper <complex451@gmail.com>

* convert mp checkpoints to nemo

Signed-off-by: ericharper <complex451@gmail.com>

* update help

Signed-off-by: ericharper <complex451@gmail.com>

* add safeguard for model parallel save_to

Signed-off-by: ericharper <complex451@gmail.com>

* adjust NLPModel save_to to be safer for model parallel

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* [BigNLP] Update GPT evaluation to work with tensor model parallel (#2959)

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* add request dataset

Signed-off-by: ericharper <complex451@gmail.com>

* tokenize request

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* able to run

Signed-off-by: ericharper <complex451@gmail.com>

* reduce logits

Signed-off-by: ericharper <complex451@gmail.com>

* capture response

Signed-off-by: ericharper <complex451@gmail.com>

* squeeze and unsqueeze

Signed-off-by: ericharper <complex451@gmail.com>

* handle non model parallel case

Signed-off-by: ericharper <complex451@gmail.com>

* clean imports

Signed-off-by: ericharper <complex451@gmail.com>

* add file

Signed-off-by: ericharper <complex451@gmail.com>

* convert logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* rename logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* add megatron gpt pretraining

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron gpt pretraining

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron gpt pretraining

Signed-off-by: ericharper <complex451@gmail.com>

* updating to work with latest megatron

Signed-off-by: ericharper <complex451@gmail.com>

* updating to work with latest megatron

Signed-off-by: ericharper <complex451@gmail.com>

* update _del_model

Signed-off-by: ericharper <complex451@gmail.com>

* adding gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* adding gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* adding gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* instantiate GPTmodel

Signed-off-by: ericharper <complex451@gmail.com>

* adding build dataset

Signed-off-by: ericharper <complex451@gmail.com>

* build megatron dataset in .setup

Signed-off-by: ericharper <complex451@gmail.com>

* setup dataloader

Signed-off-by: ericharper <complex451@gmail.com>

* add vocab_file and merge_file to megatron init

Signed-off-by: ericharper <complex451@gmail.com>

* add forward

Signed-off-by: ericharper <complex451@gmail.com>

* add train loss

Signed-off-by: ericharper <complex451@gmail.com>

* add optimizer

Signed-off-by: ericharper <complex451@gmail.com>

* add exp_manager

Signed-off-by: ericharper <complex451@gmail.com>

* multi-gpu is working

Signed-off-by: ericharper <complex451@gmail.com>

* adding val loop

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* adding val loop

Signed-off-by: ericharper <complex451@gmail.com>

* fix ranks

Signed-off-by: ericharper <complex451@gmail.com>

* fix model parallel checkpoint saving

Signed-off-by: ericharper <complex451@gmail.com>

* fix _del_model

Signed-off-by: ericharper <complex451@gmail.com>

* added megatron batch sampler

Signed-off-by: ericharper <complex451@gmail.com>

* try to fix num steps

Signed-off-by: ericharper <complex451@gmail.com>

* add wandb to config

Signed-off-by: ericharper <complex451@gmail.com>

* log lr

Signed-off-by: ericharper <complex451@gmail.com>

* add warmup ratio to config

Signed-off-by: ericharper <complex451@gmail.com>

* update configs

Signed-off-by: ericharper <complex451@gmail.com>

* update configs

Signed-off-by: ericharper <complex451@gmail.com>

* add cpu init to args

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* Initial megatron dataset port

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix merge conflicts

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* License fixes and megatron model porting

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* More fixes to import from nemo rather than megatron

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix circular imports

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Revert config file

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Restructure further to avoid circular imports

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add Makefile

Signed-off-by: ericharper <complex451@gmail.com>

* Add megatron modules

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add license

Signed-off-by: ericharper <complex451@gmail.com>

* Port from latest megatron

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update cfg

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* add _del_model_without_trainer

Signed-off-by: ericharper <complex451@gmail.com>

* add data preprocessing script

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* use apex mpu

Signed-off-by: ericharper <complex451@gmail.com>

* replace print_rank_0 with nemo utils logging

Signed-off-by: ericharper <complex451@gmail.com>

* use apex mpu

Signed-off-by: ericharper <complex451@gmail.com>

* use apex mpu

Signed-off-by: ericharper <complex451@gmail.com>

* add use_cpu_initialization

Signed-off-by: ericharper <complex451@gmail.com>

* fixing autoresume in progress

Signed-off-by: ericharper <complex451@gmail.com>

* properly removing last checkpoint

Signed-off-by: ericharper <complex451@gmail.com>

* log consumed samples

Signed-off-by: ericharper <complex451@gmail.com>

* fix mp autoresume

Signed-off-by: ericharper <complex451@gmail.com>

* add NLPSaveRestoreConnector

Signed-off-by: ericharper <complex451@gmail.com>

* Megatron GPT training with NeMo tokenizers (#2818)

* Update files from megatron repo

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove non NLP data related files from megatron

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Merge megatron and nemo tokenizers

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove get_tokenizer() calls from gpt model

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Update tokenizer yaml config

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add todo

Signed-off-by: ericharper <complex451@gmail.com>

* update config

Signed-off-by: ericharper <complex451@gmail.com>

* make init_method_std configurable

Signed-off-by: ericharper <complex451@gmail.com>

* make gpu init work by setting random seed earlier

Signed-off-by: ericharper <complex451@gmail.com>

* fix gpu init after removing debug print in mpu

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_adam

Signed-off-by: ericharper <complex451@gmail.com>

* check ds is not none before logging len

Signed-off-by: ericharper <complex451@gmail.com>

* set fp16 arg to true and fix enum conflict

Signed-off-by: ericharper <complex451@gmail.com>

* make fp16 arg configurable

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* Linear warmup with cosine annealing and constant holding (#2846)

* Testing cosine schedule

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* More fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update config for constant steps in schedule

Signed-off-by: ericharper <complex451@gmail.com>

* temporarily import enum from megatron

Signed-off-by: ericharper <complex451@gmail.com>

* add grad clip for fp32

Signed-off-by: ericharper <complex451@gmail.com>

* update check for _del_model_without_trainer

Signed-off-by: ericharper <complex451@gmail.com>

* updating restore for model parallel

Signed-off-by: ericharper <complex451@gmail.com>

* add predict script

Signed-off-by: ericharper <complex451@gmail.com>

* update test iters

Signed-off-by: ericharper <complex451@gmail.com>

* add barrier

Signed-off-by: ericharper <complex451@gmail.com>

* return if clip_val is 0 or None

Signed-off-by: ericharper <complex451@gmail.com>

* when using amp clip grads after they are unscaled

Signed-off-by: ericharper <complex451@gmail.com>

* make native amp scaler hyperparams configurable

Signed-off-by: ericharper <complex451@gmail.com>

* (1) nvfuser, (2) amp-casting decoration (#2894)

* (1) nvfuser, (2) amp-casting decoration

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* support bf16

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* update package info

Signed-off-by: ericharper <complex451@gmail.com>

* add set device to constructor

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Remove megatron-lm dependency. (#2910)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>

* NVfuser (#2943)

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* Megatron gpt bfloat support (#2926)

* Save/restore fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Another merge

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Bf16 args in init

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Set precision

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove debug stuff

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* add bf16 casting decorator

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Bfloat layernorm propagation

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* activation checkpoint recompute

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* selective nvfuser setup

* More arg removal

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove BERTDataset

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* update to latest apex and patch transformer autocast

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: ericharper <complex451@gmail.com>

* don't set jit for bf16

Signed-off-by: ericharper <complex451@gmail.com>

* replace apex.mpu

Signed-off-by: ericharper <complex451@gmail.com>

* fix grad clip

Signed-off-by: ericharper <complex451@gmail.com>

* NVFuser fixes (#2951)

* Fuser fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove dummy handler

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove PTL plugin based logic for fusion

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* remove duplicated file

Signed-off-by: ericharper <complex451@gmail.com>

* typo (#2960)

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Script to convert GPT checkpoint to .nemo (#2958)

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* remove args in progress

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add load_fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron_init

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* add fused kernels

Signed-off-by: ericharper <complex451@gmail.com>

* update process batch

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* remove erroneous import

Signed-off-by: ericharper <complex451@gmail.com>

* add megatron clip_grad

Signed-off-by: ericharper <complex451@gmail.com>

* trying to resolve circular import error

Signed-off-by: ericharper <complex451@gmail.com>

* rename file

Signed-off-by: ericharper <complex451@gmail.com>

* remove non-gpt models and datasets from __init__ files

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set device in constructor for gpu init

Signed-off-by: ericharper <complex451@gmail.com>

* set_device in constructor

Signed-off-by: ericharper <complex451@gmail.com>

* clean config

Signed-off-by: ericharper <complex451@gmail.com>

* update MegatronDataset

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* clean up MegatronModule

Signed-off-by: ericharper <complex451@gmail.com>

* rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16

Signed-off-by: ericharper <complex451@gmail.com>

* rename to fused_fp16

Signed-off-by: ericharper <complex451@gmail.com>

* add fused_fp16 arg to LayerNorm calls

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix arg name

Signed-off-by: ericharper <complex451@gmail.com>

* fix import

Signed-off-by: ericharper <complex451@gmail.com>

* update arg

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* skip warmup default to True

Signed-off-by: ericharper <complex451@gmail.com>

* Adding complete method to MegatronGPTModel (#2935)

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* make ffn_hidden_size mandatory

Signed-off-by: ericharper <complex451@gmail.com>

* Manually migrating timing of step into branch (#2937)

* 1. Manually migrating timing of step into branch.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated file name and content.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

* 1. Updated to latest code.

Signed-off-by: Micha Livne <mlivne@nvidia.com>

Co-authored-by: Micha Livne <mlivne@nvidia.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused import

Signed-off-by: ericharper <complex451@gmail.com>

* check fused_fp16 and fused_bf16 are not both True

Signed-off-by: ericharper <complex451@gmail.com>

* update predict script for model parallel .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* add script to convert .ckpt to .nemo

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update

Signed-off-by: ericharper <complex451@gmail.com>

* convert mp checkpoints to nemo

Signed-off-by: ericharper <complex451@gmail.com>

* update help

Signed-off-by: ericharper <complex451@gmail.com>

* add safeguard for model parallel save_to

Signed-off-by: ericharper <complex451@gmail.com>

* adjust NLPModel save_to to be safer for model parallel

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* [BigNLP] Update GPT evaluation to work with tensor model parallel  (#2959)

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* add request dataset

Signed-off-by: ericharper <complex451@gmail.com>

* tokenize request

Signed-off-by: ericharper <complex451@gmail.com>

* in progress

Signed-off-by: ericharper <complex451@gmail.com>

* able to run

Signed-off-by: ericharper <complex451@gmail.com>

* reduce logits

Signed-off-by: ericharper <complex451@gmail.com>

* capture response

Signed-off-by: ericharper <complex451@gmail.com>

* squeeze and unsqueeze

Signed-off-by: ericharper <complex451@gmail.com>

* handle non model parallel case

Signed-off-by: ericharper <complex451@gmail.com>

* clean imports

Signed-off-by: ericharper <complex451@gmail.com>

* add file

Signed-off-by: ericharper <complex451@gmail.com>

* convert logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* rename logits to log_probs

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* fix copyright headers

Signed-off-by: ericharper <complex451@gmail.com>

* fix copyright headers

Signed-off-by: ericharper <complex451@gmail.com>

* remove old TimingCallback

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins to use latest apex and sandeep's fork

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* try 2109 container

Signed-off-by: ericharper <complex451@gmail.com>

* try cuda container

Signed-off-by: ericharper <complex451@gmail.com>

* use internal container

Signed-off-by: ericharper <complex451@gmail.com>

* update checkpoint tests

Signed-off-by: ericharper <complex451@gmail.com>

* fix scheduler args

Signed-off-by: ericharper <complex451@gmail.com>

* update eval

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* update jenkins to use ptl 1.5 rc

Signed-off-by: ericharper <complex451@gmail.com>

* add import guard to jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* add import guard to jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* remove deterministic

Signed-off-by: ericharper <complex451@gmail.com>

* install numba .53

Signed-off-by: ericharper <complex451@gmail.com>

* allow for more variance

Signed-off-by: ericharper <complex451@gmail.com>

* update trainer config dataclass

Signed-off-by: ericharper <complex451@gmail.com>

* test_get_optimizer on gpu

Signed-off-by: ericharper <complex451@gmail.com>

* revert comment

Signed-off-by: ericharper <complex451@gmail.com>

* change trainer config default to 32

Signed-off-by: ericharper <complex451@gmail.com>

* [BigNLP] Remove fused kernel code instead use Apex (#2984)

* remove fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* remove fused_kernels

Signed-off-by: ericharper <complex451@gmail.com>

* remove fused layer norm and fused softmax and use apex instead

Signed-off-by: ericharper <complex451@gmail.com>

* update imports

Signed-off-by: ericharper <complex451@gmail.com>

* remove comment

Signed-off-by: ericharper <complex451@gmail.com>

* use apex enums

Signed-off-by: ericharper <complex451@gmail.com>

* use apex enums

Signed-off-by: ericharper <complex451@gmail.com>

* add tab

Signed-off-by: ericharper <complex451@gmail.com>

* Timer with sliding window (#3002)

Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>

* revert tab

Signed-off-by: ericharper <complex451@gmail.com>

* check for rank zero

Signed-off-by: ericharper <complex451@gmail.com>

* check for rank zero

Signed-off-by: ericharper <complex451@gmail.com>

* try explicit log dir

Signed-off-by: ericharper <complex451@gmail.com>

* add +

Signed-off-by: ericharper <complex451@gmail.com>

* don't rm

Signed-off-by: ericharper <complex451@gmail.com>

* make dir if it doesn't exist

Signed-off-by: ericharper <complex451@gmail.com>

* create mp nemo file in temp directory

Signed-off-by: ericharper <complex451@gmail.com>

* simplify mp save_to

Signed-off-by: ericharper <complex451@gmail.com>

* handle mp 1 case

Signed-off-by: ericharper <complex451@gmail.com>

* style fix

Signed-off-by: ericharper <complex451@gmail.com>

* remove files

Signed-off-by: ericharper <complex451@gmail.com>

* fix consumed_samples when resuming

Signed-off-by: ericharper <complex451@gmail.com>

* fix reinstall.sh

Signed-off-by: ericharper <complex451@gmail.com>

* update req

Signed-off-by: ericharper <complex451@gmail.com>

* add more detailed log for dataloaders

Signed-off-by: ericharper <complex451@gmail.com>

* check if cuda is available before using fused_adam

Signed-off-by: ericharper <complex451@gmail.com>

* revert comment

Signed-off-by: ericharper <complex451@gmail.com>

* update eval script to use model.freeze

Signed-off-by: ericharper <complex451@gmail.com>

* log train loss averaged over gradient accumulation steps

Signed-off-by: ericharper <complex451@gmail.com>

* check copyright earlier

Signed-off-by: ericharper <complex451@gmail.com>

* todo

Signed-off-by: ericharper <complex451@gmail.com>

* override SaveRestoreConnector in NLPModel init

Signed-off-by: ericharper <complex451@gmail.com>

* move to scripts

Signed-off-by: ericharper <complex451@gmail.com>

* remove star import

Signed-off-by: ericharper <complex451@gmail.com>

* remove comments

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused dataset

Signed-off-by: ericharper <complex451@gmail.com>

* removed barrier

Signed-off-by: ericharper <complex451@gmail.com>

* check cfg

Signed-off-by: ericharper <complex451@gmail.com>

* remove logging

Signed-off-by: ericharper <complex451@gmail.com>

* freeze, unfreeze

Signed-off-by: ericharper <complex451@gmail.com>

* return None

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused imports

Signed-off-by: ericharper <complex451@gmail.com>

* add TODO

Signed-off-by: ericharper <complex451@gmail.com>

* typecheck

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* todo

Signed-off-by: ericharper <complex451@gmail.com>

* add common native plugin

Signed-off-by: ericharper <complex451@gmail.com>

* restore with trainer

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* deprecate megatron-lm bert

Signed-off-by: ericharper <complex451@gmail.com>

* deprecate megatron-lm bert

Signed-off-by: ericharper <complex451@gmail.com>

* compile helpers on the fly

Signed-off-by: ericharper <complex451@gmail.com>

* remove amp_level

Signed-off-by: ericharper <complex451@gmail.com>

* remove amp_level from configs

Signed-off-by: ericharper <complex451@gmail.com>

* add missing import

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* remove amp_level

Signed-off-by: ericharper <complex451@gmail.com>

* use fast huggingface tokenizers by default

Signed-off-by: ericharper <complex451@gmail.com>

* deal with huggingface tokenizer positional args

Signed-off-by: ericharper <complex451@gmail.com>

* deal with huggingface tokenizer positional args

Signed-off-by: ericharper <complex451@gmail.com>

* deal with huggingface tokenizer positional args

Signed-off-by: ericharper <complex451@gmail.com>

* revert use_fast default to False

Signed-off-by: ericharper <complex451@gmail.com>

* return super training_epoch_end

Signed-off-by: ericharper <complex451@gmail.com>

* remove optimizer_idx arg from training_step

Signed-off-by: ericharper <complex451@gmail.com>

* remove unused arg from on_train_epoch_end

Signed-off-by: ericharper <complex451@gmail.com>

* add restore_from_path to nemo config

Signed-off-by: ericharper <complex451@gmail.com>

* add comment

Signed-off-by: ericharper <complex451@gmail.com>

* revert

Signed-off-by: ericharper <complex451@gmail.com>

* override connector if not subclassing NLPSaveRestoreConnector for model parallel save

Signed-off-by: ericharper <complex451@gmail.com>

* update test optimizer

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* clean up

Signed-off-by: ericharper <complex451@gmail.com>

* make data_prefix mandatory in config

Signed-off-by: ericharper <complex451@gmail.com>

* update installation instructions on readme

Signed-off-by: ericharper <complex451@gmail.com>

* update dockerfile

Signed-off-by: ericharper <complex451@gmail.com>

* add todo

Signed-off-by: ericharper <complex451@gmail.com>

* raise error if trying to use always_save_nemo with model parallel model

Signed-off-by: ericharper <complex451@gmail.com>

* remove comment

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
2021-10-20 21:06:37 -06:00


# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType, LayerType

from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
from nemo.collections.nlp.modules.common.megatron.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(
    hidden_size,
    ffn_hidden_size,
    num_layers,
    max_position_embeddings,
    num_tokentypes,
    add_pooler,
    vocab_size,
    num_attention_heads,
    encoder_attn_mask_type,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    fused_fp16=False,
    fused_bf16=False,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    openai_gelu=False,
    onnx_safe=False,
):
    """Build language model and return along with the key to save."""

    assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        encoder_attn_mask_type=encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        vocab_size=vocab_size,
        max_position_embeddings=max_position_embeddings,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        fused_fp16=fused_fp16,
        fused_bf16=fused_bf16,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        init_method,
        num_tokentypes=0,
        use_cpu_initialization=False,
    ):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size, init_method=self.init_method, use_cpu_initialization=use_cpu_initialization,
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
            destination, prefix, keep_vars
        )
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                destination, prefix, keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print(
                    '***WARNING*** expected tokentype embeddings in the checkpoint but could not find it',
                    flush=True,
                )
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        encoder_attn_mask_type,
        vocab_size,
        max_position_embeddings,
        hidden_size,
        ffn_hidden_size,
        num_layers,
        num_tokentypes,
        num_attention_heads,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        add_decoder=False,
        decoder_attn_mask_type=AttnMaskType.causal,
        add_pooler=False,
        pre_process=True,
        post_process=True,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        fused_fp16=False,
        fused_bf16=False,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
    ):
        super(TransformerLanguageModel, self).__init__()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.hidden_dropout = hidden_dropout
        self.output_layer_init_method = output_layer_init_method

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(
                hidden_size=self.hidden_size,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_position_embeddings,
                init_method=self.init_method,
                num_tokentypes=self.num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=self.hidden_dropout,
            )
            self._embedding_key = 'embedding'

        # Transformer.
        self.encoder = ParallelTransformer(
            init_method=self.init_method,
            output_layer_init_method=self.output_layer_init_method,
            num_layers=self.num_layers,
            hidden_size=self.hidden_size,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            self_attn_mask_type=self.encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process,
            fused_fp16=fused_fp16,
            fused_bf16=fused_bf16,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            hidden_dropout=hidden_dropout,
            use_cpu_initialization=use_cpu_initialization,
            bias_gelu_fusion=bias_gelu_fusion,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )
        self._encoder_key = 'encoder'

        # Decoder
        if self.add_decoder:
            assert (
                parallel_state.get_pipeline_model_parallel_world_size() == 1
            ), 'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                init_method=self.init_method,
                output_layer_init_method=self.output_layer_init_method,
                num_layers=self.num_layers,
                hidden_size=self.hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                ffn_hidden_size=ffn_hidden_size,
                pre_process=self.pre_process,
                post_process=self.post_process,
                fused_fp16=fused_fp16,
                fused_bf16=fused_bf16,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )
            self._decoder_key = 'decoder'

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)

    def forward(
        self,
        enc_input_ids,
        enc_position_ids,
        enc_attn_mask,
        dec_input_ids=None,
        dec_position_ids=None,
        dec_attn_mask=None,
        enc_dec_attn_mask=None,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        pooling_sequence_index=0,
        enc_hidden_states=None,
        output_enc_hidden=False,
    ):
        # Embeddings.
        if self.pre_process:
            embedding_output = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
            encoder_input = embedding_output
        else:
            encoder_input = None

        # encoder.
        if enc_hidden_states is None:
            encoder_output = self.encoder(
                encoder_input, enc_attn_mask, layer_past=layer_past, get_key_value=get_key_value
            )
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output, pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder Embedding
        dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids)
        # decoder
        decoder_output = self.decoder(
            dec_embedding_output,
            dec_attn_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )
        state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars
                )
        if self.add_decoder:
            state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self._encoder_key in state_dict:
            state_dict_ = state_dict[self._encoder_key]
        # for backward compatibility.
        elif 'transformer' in state_dict:
            state_dict_ = state_dict['transformer']
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]

        # for backward compatibility: older checkpoints name the
        # self-attention block ".attention." instead of ".self_attention.".
        state_dict_self_attention = {}
        for key in state_dict_.keys():
            if '.attention.' in key:
                state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
            else:
                state_dict_self_attention[key] = state_dict_[key]
        state_dict_ = state_dict_self_attention

        self.encoder.load_state_dict(state_dict_, strict=strict)

        if self.post_process:
            # pooler
            if self.add_pooler:
                assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
        # decoder
        if self.add_decoder:
            assert 'decoder' in state_dict, 'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
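
The backward-compatibility branch of `TransformerLanguageModel.load_state_dict` can be hard to follow inline, so here is a minimal standalone sketch of the encoder key remapping it performs. It uses plain dictionaries instead of tensors so it runs without `torch` or `apex`. The helper name `remap_legacy_encoder_keys` and the sample keys are hypothetical and are not part of NeMo; only the lookup order (`encoder`, then `transformer`, then flat `transformer.`-prefixed keys) and the `.attention.` to `.self_attention.` rename are taken from the listing above.

```python
def remap_legacy_encoder_keys(state_dict, encoder_key='encoder'):
    """Return the encoder sub-dict with legacy key names normalized.

    Hypothetical helper for illustration only; it follows the same lookup
    order as the encoder section of load_state_dict in the listing.
    """
    if encoder_key in state_dict:
        # Current layout: encoder weights are nested under their own key.
        sub_dict = state_dict[encoder_key]
    elif 'transformer' in state_dict:
        # Older layout: the same nesting, but under the name 'transformer'.
        sub_dict = state_dict['transformer']
    else:
        # Oldest layout: flat keys carrying a 'transformer.' prefix.
        sub_dict = {
            key.split('transformer.')[1]: value
            for key, value in state_dict.items()
            if 'transformer.' in key
        }

    # Older checkpoints call the self-attention block '.attention.'.
    return {key.replace('.attention.', '.self_attention.'): value for key, value in sub_dict.items()}


# Sample flat legacy checkpoint; the integers stand in for weight tensors.
legacy_checkpoint = {
    'transformer.layers.0.attention.dense.weight': 1,
    'transformer.layers.0.mlp.dense_h_to_4h.weight': 2,
    'embedding.word_embeddings.weight': 3,
}
remapped = remap_legacy_encoder_keys(legacy_checkpoint)
print(sorted(remapped))
```

Keys that do not belong to the encoder, such as the embedding entry in the sample, are dropped by the flat-key branch. This matches the listing, where the embedding weights are loaded by a separate branch.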