[BigNLP] Merge Megatron GPT to main (#2975)
* fix gpu init after removing debug print in mpu Signed-off-by: ericharper <complex451@gmail.com>
* add fused_adam Signed-off-by: ericharper <complex451@gmail.com>
* check ds is not none before logging len Signed-off-by: ericharper <complex451@gmail.com>
* set fp16 arg to true and fix enum conflict Signed-off-by: ericharper <complex451@gmail.com>
* make fp16 arg configurable Signed-off-by: ericharper <complex451@gmail.com>
* add grad clip from megatron Signed-off-by: ericharper <complex451@gmail.com>
* Linear warmup with cosine annealing and constant holding (#2846)
  * Testing cosine schedule Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Style fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * More fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* update config for constant steps in schedule Signed-off-by: ericharper <complex451@gmail.com>
* temporarily import enum from megatron Signed-off-by: ericharper <complex451@gmail.com>
* add grad clip for fp32 Signed-off-by: ericharper <complex451@gmail.com>
* update check for _del_model_without_trainer Signed-off-by: ericharper <complex451@gmail.com>
* updating restore for model parallel Signed-off-by: ericharper <complex451@gmail.com>
* add predict script Signed-off-by: ericharper <complex451@gmail.com>
* update test iters Signed-off-by: ericharper <complex451@gmail.com>
* add barrier Signed-off-by: ericharper <complex451@gmail.com>
* return if clip_val is 0 or None Signed-off-by: ericharper <complex451@gmail.com>
* when using amp clip grads after they are unscaled Signed-off-by: ericharper <complex451@gmail.com>
* make native amp scaler hyperparams configurable Signed-off-by: ericharper <complex451@gmail.com>
* (1) nvfuser, (2) amp-casting decoration (#2894)
  * (1) nvfuser, (2) amp-casting decoration Signed-off-by: Sangkug Lym <slym@nvidia.com>
  * support bf16 Signed-off-by: Sangkug Lym <slym@nvidia.com>
* update package info Signed-off-by: ericharper <complex451@gmail.com>
* add set device to constructor Signed-off-by: ericharper <complex451@gmail.com>
* set_device in constructor Signed-off-by: ericharper <complex451@gmail.com>
* [BigNLP] Remove megatron-lm dependency (#2910)
  * remove args Signed-off-by: ericharper <complex451@gmail.com>
  * remove args Signed-off-by: ericharper <complex451@gmail.com>
  * remove args Signed-off-by: ericharper <complex451@gmail.com>
  * remove args Signed-off-by: ericharper <complex451@gmail.com>
  * remove args in progress Signed-off-by: ericharper <complex451@gmail.com>
  * remove args in progress Signed-off-by: ericharper <complex451@gmail.com>
  * remove args in progress Signed-off-by: ericharper <complex451@gmail.com>
  * remove args in progress Signed-off-by: ericharper <complex451@gmail.com>
  * add load_fused_kernels Signed-off-by: ericharper <complex451@gmail.com>
  * add load_fused_kernels Signed-off-by: ericharper <complex451@gmail.com>
  * update megatron_init Signed-off-by: ericharper <complex451@gmail.com>
  * add fused kernels Signed-off-by: ericharper <complex451@gmail.com>
  * add fused kernels Signed-off-by: ericharper <complex451@gmail.com>
  * update process batch Signed-off-by: ericharper <complex451@gmail.com>
  * remove erroneous import Signed-off-by: ericharper <complex451@gmail.com>
  * remove erroneous import Signed-off-by: ericharper <complex451@gmail.com>
  * remove erroneous import Signed-off-by: ericharper <complex451@gmail.com>
  * add megatron clip_grad Signed-off-by: ericharper <complex451@gmail.com>
  * trying to resolve circular import error Signed-off-by: ericharper <complex451@gmail.com>
  * rename file Signed-off-by: ericharper <complex451@gmail.com>
  * remove non-gpt models and datasets from __init__ files Signed-off-by: ericharper <complex451@gmail.com>
  * set device in constructor for gpu init Signed-off-by: ericharper <complex451@gmail.com>
  * set device in constructor for gpu init Signed-off-by: ericharper <complex451@gmail.com>
  * set_device in constructor Signed-off-by: ericharper <complex451@gmail.com>
  * clean config Signed-off-by: ericharper <complex451@gmail.com>
  * update MegatronDataset Signed-off-by: ericharper <complex451@gmail.com>
  * clean up MegatronModule Signed-off-by: ericharper <complex451@gmail.com>
  * clean up MegatronModule Signed-off-by: ericharper <complex451@gmail.com>
  * rename fp16 and bf16 flags to fused_softmax_input_in_fp16/bf16 Signed-off-by: ericharper <complex451@gmail.com>
  * rename to fused_fp16 Signed-off-by: ericharper <complex451@gmail.com>
  * add fused_fp16 arg to LayerNorm calls Signed-off-by: ericharper <complex451@gmail.com>
  * fix arg name Signed-off-by: ericharper <complex451@gmail.com>
  * fix arg name Signed-off-by: ericharper <complex451@gmail.com>
  * fix import Signed-off-by: ericharper <complex451@gmail.com>
  * update arg Signed-off-by: ericharper <complex451@gmail.com>
  * skip warmup default to True Signed-off-by: ericharper <complex451@gmail.com>
  * skip warmup default to True Signed-off-by: ericharper <complex451@gmail.com>
  * Adding complete method to MegatronGPTModel (#2935) Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
  * make ffn_hidden_size mandatory Signed-off-by: ericharper <complex451@gmail.com>
  * Manually migrating timing of step into branch (#2937)
    * 1. Manually migrating timing of step into branch. Signed-off-by: Micha Livne <mlivne@nvidia.com>
    * 1. Updated file name and content. Signed-off-by: Micha Livne <mlivne@nvidia.com>
    * 1. Updated to latest code. Signed-off-by: Micha Livne <mlivne@nvidia.com>
    Co-authored-by: Micha Livne <mlivne@nvidia.com>
  * remove unused imports Signed-off-by: ericharper <complex451@gmail.com>
  * remove unused import Signed-off-by: ericharper <complex451@gmail.com>
  * remove unused import Signed-off-by: ericharper <complex451@gmail.com>
  * remove unused import Signed-off-by: ericharper <complex451@gmail.com>
  * check fused_fp16 and fused_bf16 are not both True Signed-off-by: ericharper <complex451@gmail.com>
  * update predict script for model parallel .nemo Signed-off-by: ericharper <complex451@gmail.com>
  * typo Signed-off-by: ericharper <complex451@gmail.com>
  * typo Signed-off-by: ericharper <complex451@gmail.com>
  Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
  Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
  Co-authored-by: Micha Livne <mlivne@nvidia.com>
* NVfuser (#2943)
  * activation checkpoint recompute Signed-off-by: Sangkug Lym <slym@nvidia.com>
  * selective nvfuser setup
* Megatron gpt bfloat support (#2926)
  * Save/restore fix Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Another merge Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Bf16 args in init Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Set precision Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Remove debug stuff Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * add bf16 casting decorator Signed-off-by: Sangkug Lym <slym@nvidia.com>
  * Bfloat layernorm propagation Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * activation checkpoint recompute Signed-off-by: Sangkug Lym <slym@nvidia.com>
  * selective nvfuser setup
  * More arg removal Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Remove BERTDataset Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * update to latest apex and patch transformer autocast Signed-off-by: ericharper <complex451@gmail.com>
  Co-authored-by: Sangkug Lym <slym@nvidia.com>
  Co-authored-by: ericharper <complex451@gmail.com>
* don't set jit for bf16 Signed-off-by: ericharper <complex451@gmail.com>
* replace apex.mpu Signed-off-by: ericharper <complex451@gmail.com>
* fix grad clip Signed-off-by: ericharper <complex451@gmail.com>
* NVFuser fixes (#2951)
  * Fuser fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Remove dummy handler Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Remove PTL plugin based logic for fusion Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* remove duplicated file Signed-off-by: ericharper <complex451@gmail.com>
* typo (#2960) Signed-off-by: ericharper <complex451@gmail.com>
* [BigNLP] Script to convert GPT checkpoint to .nemo (#2958); its squashed message repeats the #2910 commit list above, plus:
  * add script to convert .ckpt to .nemo Signed-off-by: ericharper <complex451@gmail.com>
  * in progress Signed-off-by: ericharper <complex451@gmail.com>
  * update Signed-off-by: ericharper <complex451@gmail.com>
  * convert mp checkpoints to nemo Signed-off-by: ericharper <complex451@gmail.com>
  * update help Signed-off-by: ericharper <complex451@gmail.com>
  * add safeguard for model parallel save_to Signed-off-by: ericharper <complex451@gmail.com>
  * adjust NLPModel save_to to be safer for model parallel Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
  Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
  Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
  Co-authored-by: Micha Livne <mlivne@nvidia.com>
  Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
* [BigNLP] Update GPT evaluation to work with tensor model parallel (#2959)
  * in progress Signed-off-by: ericharper <complex451@gmail.com>
  * update args Signed-off-by: ericharper <complex451@gmail.com>
  * add request dataset Signed-off-by: ericharper <complex451@gmail.com>
  * tokenize request Signed-off-by: ericharper <complex451@gmail.com>
  * in progress Signed-off-by: ericharper <complex451@gmail.com>
  * able to run Signed-off-by: ericharper <complex451@gmail.com>
  * reduce logits Signed-off-by: ericharper <complex451@gmail.com>
  * capture response Signed-off-by: ericharper <complex451@gmail.com>
  * squeeze and unsqueeze Signed-off-by: ericharper <complex451@gmail.com>
  * handle non model parallel case Signed-off-by: ericharper <complex451@gmail.com>
  * clean imports Signed-off-by: ericharper <complex451@gmail.com>
  * add file Signed-off-by: ericharper <complex451@gmail.com>
  * convert logits to log_probs Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
  * rename logits to log_probs Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
  Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
* add megatron gpt pretraining Signed-off-by: ericharper <complex451@gmail.com>
* add megatron gpt pretraining Signed-off-by: ericharper <complex451@gmail.com>
* add megatron gpt pretraining Signed-off-by: ericharper <complex451@gmail.com>
* updating to work with latest megatron Signed-off-by: ericharper <complex451@gmail.com>
* updating to work with latest megatron Signed-off-by: ericharper <complex451@gmail.com>
* update _del_model Signed-off-by: ericharper <complex451@gmail.com>
* adding gpt model Signed-off-by: ericharper <complex451@gmail.com>
* adding gpt model Signed-off-by: ericharper <complex451@gmail.com>
* adding gpt model Signed-off-by: ericharper <complex451@gmail.com>
* instantiate GPTmodel Signed-off-by: ericharper <complex451@gmail.com>
* adding build dataset Signed-off-by: ericharper <complex451@gmail.com>
* build megatron dataset in .setup Signed-off-by: ericharper <complex451@gmail.com>
* setup dataloader Signed-off-by: ericharper <complex451@gmail.com>
* add vocab_file and merge_file to megatron init Signed-off-by: ericharper <complex451@gmail.com>
* add forward Signed-off-by: ericharper <complex451@gmail.com>
* add train loss Signed-off-by: ericharper <complex451@gmail.com>
* add optimizer Signed-off-by: ericharper <complex451@gmail.com>
* add exp_manager Signed-off-by: ericharper <complex451@gmail.com>
* multi-gpu is working Signed-off-by: ericharper <complex451@gmail.com>
* adding val loop Signed-off-by: ericharper <complex451@gmail.com>
* style Signed-off-by: ericharper <complex451@gmail.com>
* adding val loop Signed-off-by: ericharper <complex451@gmail.com>
* fix ranks Signed-off-by: ericharper <complex451@gmail.com>
* fix model parallel checkpoint saving Signed-off-by: ericharper <complex451@gmail.com>
* fix _del_model Signed-off-by: ericharper <complex451@gmail.com>
* added megatron batch sampler Signed-off-by: ericharper <complex451@gmail.com>
* try to fix num steps Signed-off-by: ericharper <complex451@gmail.com>
* add wandb to config Signed-off-by: ericharper <complex451@gmail.com>
* log lr Signed-off-by: ericharper <complex451@gmail.com>
* add warmup ratio to config Signed-off-by: ericharper <complex451@gmail.com>
* update configs Signed-off-by: ericharper <complex451@gmail.com>
* update configs Signed-off-by: ericharper <complex451@gmail.com>
* add cpu init to args Signed-off-by: ericharper <complex451@gmail.com>
* update config Signed-off-by: ericharper <complex451@gmail.com>
* update config Signed-off-by: ericharper <complex451@gmail.com>
* Initial megatron dataset port Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Fix merge conflicts Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* License fixes and megatron model porting Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Style fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* More fixes to import from nemo rather than megatron Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Fix circular imports Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Style fixes Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Revert config file Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Restructure further to avoid circular imports Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* add Makefile Signed-off-by: ericharper <complex451@gmail.com>
* Add megatron modules Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* add license Signed-off-by: ericharper <complex451@gmail.com>
* Port from latest megatron Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* update cfg Signed-off-by: ericharper <complex451@gmail.com>
* update config Signed-off-by: ericharper <complex451@gmail.com>
* add _del_model_without_trainer Signed-off-by: ericharper <complex451@gmail.com>
* add data preprocessing script Signed-off-by: ericharper <complex451@gmail.com>
* update config Signed-off-by: ericharper <complex451@gmail.com>
* use apex mpu Signed-off-by: ericharper <complex451@gmail.com>
* replace print_rank_0 with nemo utils logging Signed-off-by: ericharper <complex451@gmail.com>
* use apex mpu Signed-off-by: ericharper <complex451@gmail.com>
* use apex mpu Signed-off-by: ericharper <complex451@gmail.com>
* add use_cpu_initialization Signed-off-by: ericharper <complex451@gmail.com>
* fixing autoresume in progress Signed-off-by: ericharper <complex451@gmail.com>
* properly removing last checkpoint Signed-off-by: ericharper <complex451@gmail.com>
* log consumed samples Signed-off-by: ericharper <complex451@gmail.com>
* fix mp autoresume Signed-off-by: ericharper <complex451@gmail.com>
* add NLPSaveRestoreConnector Signed-off-by: ericharper <complex451@gmail.com>
* Megatron GPT training with NeMo tokenizers (#2818)
  * Update files from megatron repo Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Remove non NLP data related files from megatron Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Merge megatron and nemo tokenizers Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Remove get_tokenizer() calls from gpt model Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  * Update tokenizer yaml config Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* add todo Signed-off-by: ericharper <complex451@gmail.com>
* update config Signed-off-by: ericharper <complex451@gmail.com>
* make init_method_std configurable Signed-off-by: ericharper <complex451@gmail.com>
* make gpu init work by setting random seed earlier Signed-off-by: ericharper <complex451@gmail.com>
* style Signed-off-by: ericharper <complex451@gmail.com>
* fix copyright headers Signed-off-by: ericharper <complex451@gmail.com>
* fix copyright headers Signed-off-by: ericharper <complex451@gmail.com>
* remove old TimingCallback Signed-off-by: ericharper <complex451@gmail.com>
* style Signed-off-by: ericharper <complex451@gmail.com>
* update jenkins to use latest apex and sandeep's fork Signed-off-by: ericharper <complex451@gmail.com>
* update jenkins Signed-off-by: ericharper <complex451@gmail.com>
* update jenkins Signed-off-by: ericharper <complex451@gmail.com>
* update jenkins Signed-off-by: ericharper <complex451@gmail.com>
* update jenkins Signed-off-by: ericharper <complex451@gmail.com>
* try 2109 container Signed-off-by: ericharper <complex451@gmail.com>
* try cuda container Signed-off-by: ericharper <complex451@gmail.com>
* use internal container Signed-off-by: ericharper <complex451@gmail.com>
* update checkpoint tests Signed-off-by: ericharper <complex451@gmail.com>
* fix scheduler args Signed-off-by: ericharper <complex451@gmail.com>
* update eval Signed-off-by: ericharper <complex451@gmail.com>
* style Signed-off-by: ericharper <complex451@gmail.com>
* update jenkins to use ptl 1.5 rc Signed-off-by: ericharper <complex451@gmail.com>
* add import guard to jenkins Signed-off-by: ericharper <complex451@gmail.com>
* add import guard to jenkins Signed-off-by: ericharper <complex451@gmail.com>
* remove deterministic Signed-off-by: ericharper <complex451@gmail.com>
* install numba .53 Signed-off-by: ericharper <complex451@gmail.com>
* allow for more variance Signed-off-by: ericharper <complex451@gmail.com>
* update trainer config dataclass Signed-off-by: ericharper <complex451@gmail.com>
* test_get_optimizer on gpu Signed-off-by: ericharper <complex451@gmail.com>
* revert comment Signed-off-by: ericharper <complex451@gmail.com>
* change trainer config default to 32 Signed-off-by: ericharper <complex451@gmail.com>
* [BigNLP] Remove fused kernel code instead use Apex (#2984)
  * remove fused_kernels Signed-off-by: ericharper <complex451@gmail.com>
  * remove fused_kernels Signed-off-by: ericharper <complex451@gmail.com>
  * remove fused layer norm and fused softmax and use apex instead Signed-off-by: ericharper <complex451@gmail.com>
  * update imports Signed-off-by: ericharper <complex451@gmail.com>
  * remove comment Signed-off-by: ericharper <complex451@gmail.com>
  * use apex enums Signed-off-by: ericharper <complex451@gmail.com>
  * use apex enums Signed-off-by: ericharper <complex451@gmail.com>
  * add tab Signed-off-by: ericharper <complex451@gmail.com>
* Timer with sliding window (#3002) Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
* revert tab Signed-off-by: ericharper <complex451@gmail.com>
* check for rank zero Signed-off-by: ericharper <complex451@gmail.com>
* check for rank zero Signed-off-by: ericharper <complex451@gmail.com>
* try explicit log dir Signed-off-by: ericharper <complex451@gmail.com>
* add + Signed-off-by: ericharper <complex451@gmail.com>
* don't rm Signed-off-by: ericharper <complex451@gmail.com>
* make dir if it doesn't exist Signed-off-by: ericharper <complex451@gmail.com>
* create mp nemo file in temp directory Signed-off-by: ericharper <complex451@gmail.com>
* simplify mp save_to Signed-off-by: ericharper <complex451@gmail.com>
* handle mp 1 case Signed-off-by: ericharper <complex451@gmail.com>
* style fix Signed-off-by: ericharper <complex451@gmail.com>
* remove files Signed-off-by: ericharper <complex451@gmail.com>
* fix consumed_samples when resuming Signed-off-by: ericharper <complex451@gmail.com>
* fix reinstall.sh Signed-off-by: ericharper <complex451@gmail.com>
* update req Signed-off-by: ericharper <complex451@gmail.com>
* add more detailed log for dataloaders Signed-off-by: ericharper <complex451@gmail.com>
* check if cuda is available before using fused_adam Signed-off-by: ericharper <complex451@gmail.com>
* revert comment Signed-off-by: ericharper <complex451@gmail.com>
* update eval script to use model.freeze Signed-off-by: ericharper <complex451@gmail.com>
* log train loss averaged over gradient accumulation steps Signed-off-by: ericharper <complex451@gmail.com>
* check copyright earlier Signed-off-by: ericharper <complex451@gmail.com>
* todo Signed-off-by: ericharper <complex451@gmail.com>
* override SaveRestoreConnector in NLPModel init Signed-off-by: ericharper <complex451@gmail.com>
* move to scripts Signed-off-by: ericharper <complex451@gmail.com>
* remove star import Signed-off-by: ericharper <complex451@gmail.com>
* remove comments Signed-off-by: ericharper <complex451@gmail.com>
* remove unused dataset Signed-off-by: ericharper <complex451@gmail.com>
* removed barrier Signed-off-by: ericharper <complex451@gmail.com>
* check cfg Signed-off-by: ericharper <complex451@gmail.com>
* remove logging Signed-off-by: ericharper <complex451@gmail.com>
* freeze, unfreeze Signed-off-by: ericharper <complex451@gmail.com>
* return None Signed-off-by: ericharper <complex451@gmail.com>
* remove unused imports Signed-off-by: ericharper <complex451@gmail.com>
* add TODO Signed-off-by: ericharper <complex451@gmail.com>
* typecheck Signed-off-by: ericharper <complex451@gmail.com>
* typo Signed-off-by: ericharper <complex451@gmail.com>
* todo Signed-off-by: ericharper <complex451@gmail.com>
* add common native plugin Signed-off-by: ericharper <complex451@gmail.com>
* restore with trainer Signed-off-by: ericharper <complex451@gmail.com>
* style Signed-off-by: ericharper <complex451@gmail.com>
* deprecate megatron-lm bert Signed-off-by: ericharper <complex451@gmail.com>
* deprecate megatron-lm bert Signed-off-by: ericharper <complex451@gmail.com>
* compile helpers on the fly Signed-off-by: ericharper <complex451@gmail.com>
* remove amp_level Signed-off-by: ericharper <complex451@gmail.com>
* remove amp_level from configs Signed-off-by: ericharper <complex451@gmail.com>
* add missing import Signed-off-by: ericharper <complex451@gmail.com>
* typo Signed-off-by: ericharper <complex451@gmail.com>
* remove amp_level Signed-off-by: ericharper <complex451@gmail.com>
* use fast huggingface tokenizers by default Signed-off-by: ericharper <complex451@gmail.com>
* deal with huggingface tokenizer positional args Signed-off-by: ericharper <complex451@gmail.com>
* deal with huggingface tokenizer positional args Signed-off-by: ericharper <complex451@gmail.com>
* deal with huggingface tokenizer positional args Signed-off-by: ericharper <complex451@gmail.com>
* revert use_fast default to False Signed-off-by: ericharper <complex451@gmail.com>
* return super training_epoch_end Signed-off-by: ericharper <complex451@gmail.com>
* remove optimizer_idx arg from training_step Signed-off-by: ericharper <complex451@gmail.com>
* remove unused arg from on_train_epoch_end Signed-off-by: ericharper <complex451@gmail.com>
* add restore_from_path to nemo config Signed-off-by: ericharper <complex451@gmail.com>
* add comment Signed-off-by: ericharper <complex451@gmail.com>
* revert Signed-off-by: ericharper <complex451@gmail.com>
* override connector if not subclassing NLPSaveRestoreConnector for model parallel save Signed-off-by: ericharper <complex451@gmail.com>
* update test optimizer Signed-off-by: ericharper <complex451@gmail.com>
* clean up Signed-off-by: ericharper <complex451@gmail.com>
* clean up Signed-off-by: ericharper <complex451@gmail.com>
* clean up Signed-off-by: ericharper <complex451@gmail.com>
* clean up Signed-off-by: ericharper <complex451@gmail.com>
* make data_prefix mandatory in config Signed-off-by: ericharper <complex451@gmail.com>
* update installation instructions on readme Signed-off-by: ericharper <complex451@gmail.com>
* update dockerfile Signed-off-by: ericharper <complex451@gmail.com>
* add todo Signed-off-by: ericharper <complex451@gmail.com>
* raise error if trying to use always_save_nemo with model parallel model Signed-off-by: ericharper <complex451@gmail.com>
* remove comment Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Micha Livne <michalivne@users.noreply.github.com>
Co-authored-by: Micha Livne <mlivne@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
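PR #2846 in the list above introduces a learning-rate policy of linear warmup, cosine annealing, and then a constant hold. The commit does not show the actual NeMo scheduler class, so the following is only a minimal sketch of the schedule's shape; all names and the `min_ratio` floor are illustrative:

```python
import math

def lr_multiplier(step: int, warmup_steps: int, decay_steps: int, min_ratio: float = 0.1) -> float:
    """Illustrative warmup -> cosine annealing -> constant-hold LR multiplier."""
    if step < warmup_steps:
        # linear warmup from 0 up to the base learning rate
        return step / max(1, warmup_steps)
    if step < warmup_steps + decay_steps:
        # cosine annealing from 1.0 down to min_ratio
        progress = (step - warmup_steps) / max(1, decay_steps)
        return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
    # constant holding at the floor for the remaining "constant steps"
    return min_ratio
```

Such a multiplier could be wrapped with `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_multiplier(s, 1000, 10000))`; the "update config for constant steps in schedule" commit suggests the hold length is config-driven in the real implementation.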
parent 620e8a8986
commit 32fa5cfaf3
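Several commits above ("when using amp clip grads after they are unscaled", "return if clip_val is 0 or None") describe the gradient-clipping order required by native AMP. A minimal sketch of that pattern with stock PyTorch; the surrounding training-step function is hypothetical, not NeMo's actual code:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, optimizer, loss, clip_val=1.0):
    scaler.scale(loss).backward()
    # Gradients must be unscaled before clipping; otherwise clip_val is
    # compared against gradients still multiplied by the loss scale.
    scaler.unscale_(optimizer)
    if clip_val is not None and clip_val > 0:  # return early if 0 or None
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
    scaler.step(optimizer)  # skips the update when grads contain inf/nan
    scaler.update()
    optimizer.zero_grad()
```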
Dockerfile
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:21.08-py3
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:21.10-py3


 # build an image that includes only the nemo dependencies, ensures that dependencies
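The base-image bump is relevant to the optimizer commits above ("add fused_adam", "check if cuda is available before using fused_adam"): the NGC PyTorch containers bundle NVIDIA Apex, and Apex's FusedAdam only operates on CUDA tensors. A hedged sketch of such a guard; the CPU fallback choice is illustrative, not necessarily what NeMo does:

```python
import torch

def build_optimizer(params, lr=1e-4):
    if torch.cuda.is_available():
        # Apex's fused multi-tensor Adam kernel; valid only on CUDA tensors.
        from apex.optimizers import FusedAdam
        return FusedAdam(params, lr=lr)
    # Illustrative CPU fallback: plain torch Adam.
    return torch.optim.Adam(params, lr=lr)
```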
Jenkinsfile (vendored, 403 changed lines)
@@ -1,7 +1,7 @@
 pipeline {
     agent {
         docker {
-            image 'nvcr.io/nvidia/pytorch:21.08-py3'
+            image 'gitlab-master.nvidia.com/dl/dgx/pytorch:21.10-py3-devel'
             args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache/torch:/root/.cache/torch --shm-size=8g'
         }
     }
@@ -10,6 +10,7 @@
         disableConcurrentBuilds()
     }
     stages {
+
         stage('PyTorch version') {
             steps {
                 sh 'python -c "import torch; print(torch.__version__)"'
@@ -23,6 +24,12 @@
             }
         }
+
+        stage('Code formatting checks') {
+            steps {
+                sh 'python setup.py style'
+            }
+        }

         stage('Copyright Headers check') {
             steps {
                 sh 'python tests/check_copyright_header.py --dir .'
@@ -35,12 +42,6 @@
             }
         }
-
-        stage('Code formatting checks') {
-            steps {
-                sh 'python setup.py style'
-            }
-        }

         stage('Torch TTS unit tests') {
             when {
                 anyOf {
@@ -58,7 +59,14 @@
             }
         }

-        stage('Installation') {
+        // Revert once import guards are added by PTL or version comparing is fixed
+        stage('Install PyTorch Lighting 1.5 RC') {
+            steps{
+                sh 'pip install pytorch-lightning==1.5.0rc0 && sed -i "s/from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining/try:\\n\\tfrom pytorch_lightning.callbacks.quantization import QuantizationAwareTraining\\nexcept:\\n\\tpass/g" /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py'
+            }
+        }
+
+        stage('NeMo Installation') {
             steps {
                 sh './reinstall.sh release'
             }
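The sed one-liner in the new install stage is dense, but all it does is wrap one import in the installed package in a try/except so that a missing symbol cannot break importing the callbacks package. After patching, the relevant lines of `pytorch_lightning/callbacks/__init__.py` read roughly:

```python
# pytorch_lightning/callbacks/__init__.py after the sed patch (approximate):
try:
    from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
except:
    pass
```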
@@ -605,30 +613,31 @@
             }
         }

-        stage('L2: Multi-GPU Megatron finetuning') {
-            when {
-                anyOf {
-                    branch 'main'
-                    changeRequest target: 'main'
-                }
-            }
-            failFast true
-            parallel {
-                stage('L2: Cased Megatron finetuning on MRPC') {
-                    steps {
-                        sh 'cd examples/nlp/glue_benchmark && \
-                        python glue_benchmark.py \
-                        model.dataset.data_dir=/home/TestData/nlp/glue_fake/MRPC \
-                        trainer.gpus=[0,1] \
-                        +trainer.fast_dev_run=true \
-                        model.dataset.use_cache=false \
-                        model.language_model.pretrained_model_name=megatron-bert-345m-cased \
-                        trainer.accelerator=ddp \
-                        exp_manager=null'
-                    }
-                }
-            }
-        }
+        // TODO: add test once megatron-bert is supported again
+        // stage('L2: Multi-GPU Megatron finetuning') {
+        //     when {
+        //         anyOf {
+        //             branch 'main'
+        //             changeRequest target: 'main'
+        //         }
+        //     }
+        //     failFast true
+        //     parallel {
+        //         stage('L2: Cased Megatron finetuning on MRPC') {
+        //             steps {
+        //                 sh 'cd examples/nlp/glue_benchmark && \
+        //                 python glue_benchmark.py \
+        //                 model.dataset.data_dir=/home/TestData/nlp/glue_fake/MRPC \
+        //                 trainer.gpus=[0,1] \
+        //                 +trainer.fast_dev_run=true \
+        //                 model.dataset.use_cache=false \
+        //                 model.language_model.pretrained_model_name=megatron-bert-345m-cased \
+        //                 trainer.accelerator=ddp \
+        //                 exp_manager=null'
+        //             }
+        //         }
+        //     }
+        // }

         stage('L2: SGD-QA') {
             when {
@@ -725,7 +734,6 @@
                         model.language_model.pretrained_model_name=bert-base-uncased \
                         model.dataset.version_2_with_negative=false \
                         trainer.precision=16 \
-                        trainer.amp_level=O1 \
                         trainer.gpus=[0] \
                         exp_manager=null'
                     }
@@ -747,7 +755,6 @@
                         model.language_model.pretrained_model_name=bert-base-uncased \
                         model.dataset.version_2_with_negative=true \
                         trainer.precision=16 \
-                        trainer.amp_level=O1 \
                         trainer.gpus=[1] \
                         exp_manager=null'
                     }
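The dropped `trainer.amp_level=O1` overrides in these hunks track the "remove amp_level" and "add common native plugin" commits: with PyTorch Lightning's native AMP, `precision=16` alone drives `torch.cuda.amp`, and apex-style opt levels no longer apply. A minimal sketch; the model object is a placeholder:

```python
import pytorch_lightning as pl

# Native AMP: precision=16 enables autocast + GradScaler internally;
# no apex amp_level / opt_level argument is involved.
trainer = pl.Trainer(gpus=[0], precision=16, max_epochs=1)
# trainer.fit(model)  # `model` would be any LightningModule (placeholder)
```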
@@ -777,30 +784,31 @@
             }
         }
-        // Runs out of memory on the 12G TITAN V (GPU 0 on main CI)
-        stage('L2: MegaBERT Token Classification') {
-            when {
-                anyOf {
-                    branch 'main'
-                    changeRequest target: 'main'
-                }
-            }
-            failFast true
-            steps {
-                sh 'cd examples/nlp/token_classification && \
-                python token_classification_train.py \
-                model.dataset.data_dir=/home/TestData/nlp/token_classification_punctuation/ \
-                model.language_model.pretrained_model_name=megatron-bert-345m-uncased \
-                model.train_ds.batch_size=10 \
-                model.dataset.max_seq_length=50 \
-                model.dataset.use_cache=false \
-                trainer.accelerator=ddp \
-                trainer.precision=16 \
-                trainer.amp_level=O1 \
-                trainer.gpus=[1] \
-                +trainer.fast_dev_run=true \
-                exp_manager=null'
-            }
-        }
+        // TODO: add when megatron bert is supported again
+        // stage('L2: MegaBERT Token Classification') {
+        //     when {
+        //         anyOf {
+        //             branch 'main'
+        //             changeRequest target: 'main'
+        //         }
+        //     }
+        //     failFast true
+        //     steps {
+        //         sh 'cd examples/nlp/token_classification && \
+        //         python token_classification_train.py \
+        //         model.dataset.data_dir=/home/TestData/nlp/token_classification_punctuation/ \
+        //         model.language_model.pretrained_model_name=megatron-bert-345m-uncased \
+        //         model.train_ds.batch_size=10 \
+        //         model.dataset.max_seq_length=50 \
+        //         model.dataset.use_cache=false \
+        //         trainer.accelerator=ddp \
+        //         trainer.precision=16 \
+        //         trainer.gpus=[1] \
+        //         +trainer.fast_dev_run=true \
+        //         exp_manager=null'
+        //     }
+        // }

         stage('L2: Parallel SQUAD v1.1 & v2.0') {
             when {
                 anyOf {
@@ -810,8 +818,13 @@
             }
             failFast true
             parallel {
-                stage('SQUAD v2.0 with Megatron with ckpt & config') {
+                // TODO: use megatron bert when supported again
+                stage('SQUAD v2.0 with DistilBERT Uncased') {
+                // stage('SQUAD v2.0 with Megatron with ckpt & config') {
                     // Cannot do fast_dev_run because squad needs whole dev dataset
+                    // model.language_model.pretrained_model_name=megatron-bert-uncased \
+                    // model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \
+                    // model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \
                     steps {
                         sh 'cd examples/nlp/question_answering && \
                         python question_answering_squad.py \
@@ -825,12 +838,9 @@
                         trainer.max_epochs=1 \
                         +trainer.max_steps=1 \
                         model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
-                        model.language_model.pretrained_model_name=megatron-bert-uncased \
-                        model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \
-                        model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \
+                        model.language_model.pretrained_model_name=distilbert-base-uncased \
                         model.dataset.version_2_with_negative=true \
                         trainer.precision=16 \
-                        trainer.amp_level=O1 \
                         trainer.gpus=[1] \
                         exp_manager=null'
                     }
@@ -852,7 +862,6 @@
                         model.language_model.pretrained_model_name=roberta-base \
                         model.dataset.version_2_with_negative=false \
                         trainer.precision=16 \
-                        trainer.amp_level=O1 \
                         trainer.gpus=[0] \
                         exp_manager=null'
                     }
@@ -864,7 +873,7 @@
                 model.dataset.num_classes=6 \
                 model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
                 model.validation_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
-                model.language_model.pretrained_model_name=bert-base-uncased \
+                model.language_model.pretrained_model_name=distilbert-base-uncased \
                 model.train_ds.batch_size=10 \
                 model.dataset.max_seq_length=50 \
                 model.dataset.use_cache=false \
@@ -889,107 +898,106 @@
             }
         }

-        stage('L2: Model Parallel Size 2 Megatron Text Classification') {
-            when {
-                anyOf{
-                    branch 'main'
-                    changeRequest target: 'main'
-                }
-            }
-            failFast true
-            steps{
-                sh 'cd examples/nlp/text_classification && \
-                python text_classification_with_bert.py \
-                trainer.gpus=[0,1] \
-                trainer.num_nodes=1 \
-                trainer.precision=16 \
-                trainer.gradient_clip_val=1.0 \
-                ~trainer.amp_level \
-                +trainer.fast_dev_run=true \
-                model.dataset.num_classes=6 \
-                model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
-                model.train_ds.batch_size=4 \
-                model.language_model.pretrained_model_name=megatron-bert-uncased \
-                model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
-                model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
-                model.nemo_path=null \
-                ~model.infer_samples \
-                exp_manager=null'
-            }
-        }
+        // TODO: add when megatron-bert is supported again
+        // stage('L2: Model Parallel Size 2 Megatron Text Classification') {
+        //     when {
+        //         anyOf{
+        //             branch 'main'
+        //             changeRequest target: 'main'
+        //         }
+        //     }
+        //     failFast true
+        //     steps{
+        //         sh 'cd examples/nlp/text_classification && \
+        //         python text_classification_with_bert.py \
+        //         trainer.gpus=[0,1] \
+        //         trainer.num_nodes=1 \
+        //         trainer.precision=16 \
+        //         trainer.gradient_clip_val=1.0 \
+        //         +trainer.fast_dev_run=true \
+        //         model.dataset.num_classes=6 \
+        //         model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
+        //         model.train_ds.batch_size=4 \
+        //         model.language_model.pretrained_model_name=megatron-bert-uncased \
+        //         model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
+        //         model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
+        //         model.nemo_path=null \
+        //         ~model.infer_samples \
+        //         exp_manager=null'
+        //     }
+        // }

-        stage('L2: Model Parallel Size 2 Megatron Autoresume') {
-            when {
-                anyOf{
-                    branch 'main'
-                    changeRequest target: 'main'
-                }
-            }
-            failFast true
-            steps{
-                sh 'cd examples/nlp/text_classification && \
-                python text_classification_with_bert.py \
-                trainer.gpus=[0,1] \
-                trainer.num_nodes=1 \
-                trainer.precision=16 \
-                trainer.gradient_clip_val=1.0 \
-                trainer.max_epochs=1 \
-                ~trainer.amp_level \
-                +trainer.fast_dev_run=true \
-                model.dataset.num_classes=6 \
-                model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
-                model.train_ds.batch_size=4 \
-                model.language_model.pretrained_model_name=megatron-bert-uncased \
-                model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
-                model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
-                model.nemo_path=null \
-                ~model.infer_samples \
-                +exp_manager.explicit_log_dir=/home/TestData/nlp/mp_autoresume \
-                +exp_manager.resume_if_exists=true'
-            }
-        }
+        // stage('L2: Model Parallel Size 2 Megatron Autoresume') {
+        //     when {
+        //         anyOf{
+        //             branch 'main'
+        //             changeRequest target: 'main'
+        //         }
+        //     }
+        //     failFast true
+        //     steps{
+        //         sh 'cd examples/nlp/text_classification && \
+        //         python text_classification_with_bert.py \
+        //         trainer.gpus=[0,1] \
+        //         trainer.num_nodes=1 \
+        //         trainer.precision=16 \
+        //         trainer.gradient_clip_val=1.0 \
+        //         trainer.max_epochs=1 \
+        //         +trainer.fast_dev_run=true \
+        //         model.dataset.num_classes=6 \
+        //         model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
+        //         model.train_ds.batch_size=4 \
+        //         model.language_model.pretrained_model_name=megatron-bert-uncased \
+        //         model.language_model.config_file=/home/TestData/nlp/mp_2_bert_toy/config.json \
+        //         model.language_model.lm_checkpoint=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
+        //         model.nemo_path=null \
+        //         ~model.infer_samples \
+        //         +exp_manager.explicit_log_dir=/home/TestData/nlp/mp_autoresume \
+        //         +exp_manager.resume_if_exists=true'
+        //     }
+        // }

-        stage('L2: Model Parallel Size 2 Megatron Evaluation from .nemo') {
-            when {
-                anyOf{
-                    branch 'main'
-                    changeRequest target: 'main'
-                }
-            }
-            failFast true
-            steps{
-                sh 'cd examples/nlp/text_classification && \
-                python model_parallel_text_classification_evaluation.py \
-                trainer.gpus=[0,1] \
-                trainer.num_nodes=1 \
-                model.dataset.num_classes=6 \
-                model.test_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
-                model.nemo_path=/home/TestData/nlp/mp_2_nemo/retail_text_class_350M.nemo \
-                exp_manager=null'
-            }
-        }
+        // stage('L2: Model Parallel Size 2 Megatron Evaluation from .nemo') {
+        //     when {
+        //         anyOf{
+        //             branch 'main'
+        //             changeRequest target: 'main'
+        //         }
+        //     }
+        //     failFast true
+        //     steps{
+        //         sh 'cd examples/nlp/text_classification && \
+        //         python model_parallel_text_classification_evaluation.py \
+        //         trainer.gpus=[0,1] \
+        //         trainer.num_nodes=1 \
+        //         model.dataset.num_classes=6 \
+        //         model.test_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
+        //         model.nemo_path=/home/TestData/nlp/mp_2_nemo/retail_text_class_350M.nemo \
+        //         exp_manager=null'
+        //     }
+        // }

-        stage('L2: Model Parallel Size 2 Megatron Train from .nemo') {
-            when {
-                anyOf{
-                    branch 'main'
-                    changeRequest target: 'main'
-                }
-            }
-            failFast true
-            steps{
-                sh 'cd examples/nlp/token_classification && \
-                python token_classification_train.py \
-                pretrained_model=/home/TestData/nlp/mp_2_nemo/ner_350M.nemo \
-                model.dataset.data_dir=/home/TestData/nlp/ner/ \
-                model.train_ds.batch_size=2 \
-                model.dataset.use_cache=false \
-                trainer.gpus=[0,1] \
-                +trainer.fast_dev_run=true \
-                model.dataset.class_balancing="weighted_loss" \
-                exp_manager=null'
-            }
-        }
+        // stage('L2: Model Parallel Size 2 Megatron Train from .nemo') {
+        //     when {
+        //         anyOf{
+        //             branch 'main'
+        //             changeRequest target: 'main'
+        //         }
+        //     }
+        //     failFast true
+        //     steps{
+        //         sh 'cd examples/nlp/token_classification && \
+        //         python token_classification_train.py \
+        //         pretrained_model=/home/TestData/nlp/mp_2_nemo/ner_350M.nemo \
+        //         model.dataset.data_dir=/home/TestData/nlp/ner/ \
+        //         model.train_ds.batch_size=2 \
+        //         model.dataset.use_cache=false \
+        //         trainer.gpus=[0,1] \
+        //         +trainer.fast_dev_run=true \
+        //         model.dataset.class_balancing="weighted_loss" \
+        //         exp_manager=null'
+        //     }
+        // }

         stage('L2: Parallel NLP Examples 2') {
             when {
@@ -1089,7 +1097,6 @@ pipeline {
         --config-name=bert_pretraining_from_text_config.yaml \
         trainer.gpus=[0] \
         trainer.precision=16 \
-        trainer.amp_level=O1 \
         +trainer.fast_dev_run=true \
         model.train_ds.data_file=/home/TestData/nlp/wikitext-2/train.txt \
         model.train_ds.batch_size=32 \
@@ -1116,7 +1123,6 @@ pipeline {
         --config-name=bert_pretraining_from_preprocessed_config.yaml \
         trainer.gpus=[1] \
         trainer.precision=16 \
-        trainer.amp_level=O1 \
         +trainer.fast_dev_run=true \
         model.train_ds.data_file=/home/TestData/nlp/wiki_book_mini/training \
         model.train_ds.batch_size=8 \
@@ -1351,39 +1357,40 @@ pipeline {
      }
    }

    stage('L2: NMT Megatron BERT Model Parallel Size 2 Encoder') {
      when {
        anyOf{
          branch 'main'
          changeRequest target: 'main'
        }
      }
      failFast true
      steps{
        sh 'cd examples/nlp/machine_translation && \
        python enc_dec_nmt.py \
        --config-path=conf \
        --config-name=megatron \
        model.encoder.model_name=megatron-bert-uncased \
        model.encoder.checkpoint_file=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
        model.encoder.hidden_size=1024 \
        model.encoder.num_attention_heads=16 \
        model.encoder.num_layers=24 \
        model.encoder.max_position_embeddings=512 \
        model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
        model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
        model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
        model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
        model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
        model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
        model.decoder_tokenizer.tokenizer_model=/home/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
        model.decoder.hidden_size=1024 \
        trainer.gpus=[0,1] \
        +trainer.fast_dev_run=true \
        exp_manager=null \
        '
      }
    }
    // TODO: add when megatron bert is supported again in NeMo
    // stage('L2: NMT Megatron BERT Model Parallel Size 2 Encoder') {
    //   when {
    //     anyOf{
    //       branch 'main'
    //       changeRequest target: 'main'
    //     }
    //   }
    //   failFast true
    //   steps{
    //     sh 'cd examples/nlp/machine_translation && \
    //     python enc_dec_nmt.py \
    //     --config-path=conf \
    //     --config-name=megatron \
    //     model.encoder.model_name=megatron-bert-uncased \
    //     model.encoder.checkpoint_file=/home/TestData/nlp/mp_2_bert_toy/iter_2000000 \
    //     model.encoder.hidden_size=1024 \
    //     model.encoder.num_attention_heads=16 \
    //     model.encoder.num_layers=24 \
    //     model.encoder.max_position_embeddings=512 \
    //     model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
    //     model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
    //     model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
    //     model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
    //     model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
    //     model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
    //     model.decoder_tokenizer.tokenizer_model=/home/TestData/nlp/nmt/toy_data/tt_tokenizer.BPE.4096.model \
    //     model.decoder.hidden_size=1024 \
    //     trainer.gpus=[0,1] \
    //     +trainer.fast_dev_run=true \
    //     exp_manager=null \
    //     '
    //   }
    // }

    stage('L2: NMT Tarred Dataset Creation') {
      when {
README.rst
@@ -85,7 +85,7 @@ Requirements
 ------------

 1) Python 3.6, 3.7 or 3.8
-2) Pytorch 1.8.1 or above
+2) Pytorch 1.10.0 or above
 3) NVIDIA GPU for training

 Documentation
@@ -163,16 +163,28 @@ Note that RNNT requires numba to be installed from conda.
   pip uninstall numba
   conda install -c numba numba

+Megatron GPT
+~~~~~~~~~~~~
+Megatron GPT training requires NVIDIA Apex to be installed.
+
+.. code-block:: bash
+
+    git clone https://github.com/NVIDIA/apex
+    cd apex
+    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+
 Docker containers:
 ~~~~~~~~~~~~~~~~~~

-If you chose to work with main branch, we recommend using NVIDIA's PyTorch container version 21.08-py3 and then installing from GitHub.
+If you choose to work with the main branch, we recommend using NVIDIA's PyTorch container version 21.10-py3 and then installing from GitHub.
+Note that NVIDIA's PyTorch 21.10-py3 container has not yet been released publicly. Please use a container with the nightly version of PyTorch installed if you are
+unable to access NVIDIA's PyTorch 21.10 container.

 .. code-block:: bash

     docker run --gpus all -it --rm -v <nemo_github_folder>:/NeMo --shm-size=8g \
     -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit \
-    stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:21.08-py3
+    stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:21.10-py3

 Examples
 --------
@@ -156,7 +156,6 @@ trainer:
   accelerator: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   log_every_n_steps: 10  # Interval of logging.
   progress_bar_refresh_rate: 10

@@ -131,7 +131,6 @@ trainer:
   accelerator: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   log_every_n_steps: 10  # Interval of logging.
   progress_bar_refresh_rate: 10

@@ -206,7 +206,6 @@ trainer:
   accelerator: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   log_every_n_steps: 10  # Interval of logging.
   progress_bar_refresh_rate: 10

@@ -201,7 +201,6 @@ trainer:
   accelerator: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   log_every_n_steps: 10  # Interval of logging.
   progress_bar_refresh_rate: 10

@@ -100,7 +100,6 @@ trainer:
   accelerator: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   log_every_n_steps: 10  # Interval of logging.
   resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.

@@ -100,7 +100,6 @@ trainer:
   accelerator: ddp
   accumulate_grad_batches: 1
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   log_every_n_steps: 10  # Interval of logging.
   resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.

@@ -22,7 +22,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   gradient_clip_val: 1.0
-  amp_level: O1 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP.
   accelerator: ddp
   log_every_n_steps: 5  # Interval of logging.

@@ -15,7 +15,6 @@ tagger_trainer:
   logger: false # provided by exp_manager
   accumulate_grad_batches: 1 # accumulates grads every k batches
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   accelerator: ddp

@@ -66,7 +65,6 @@ decoder_trainer:
   logger: false # provided by exp_manager
   accumulate_grad_batches: 1 # accumulates grads every k batches
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   accelerator: ddp
   log_every_n_steps: 1  # Interval of logging.

@@ -7,7 +7,6 @@ trainer:
   max_steps: null
   accumulate_grad_batches: 1
   precision: 16
-  amp_level: O1
   accelerator: ddp
   gradient_clip_val: 0.0
   log_every_n_steps: 1

@@ -7,7 +7,6 @@ trainer:
   max_steps: null
   accumulate_grad_batches: 1
   precision: 16
-  amp_level: O1
   accelerator: ddp
   gradient_clip_val: 0.0
   log_every_n_steps: 1

@@ -7,7 +7,6 @@ trainer:
   max_epochs: 3
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 16
   accelerator: ddp
   checkpoint_callback: False # Provided by exp_manager

@@ -7,7 +7,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   precision: 16 # 16 to use AMP
-  amp_level: O1 # O1 or O2 if using AMP
   accelerator: ddp
   log_every_n_steps: 1  # Interval of logging.
   val_check_interval: 0.05  # Set to 0.25 to check 4 times per epoch, or an int for number of iterations

@@ -6,7 +6,6 @@ trainer:
   max_epochs: 50
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 amp_level to enable the AMP.
   accelerator: ddp
   log_every_n_steps: 1  # Interval of logging.

@@ -8,7 +8,6 @@ trainer:
   replace_sampler_ddp: false # needed for bert pretraining from preproc
   accumulate_grad_batches: 1 # accumulates grads every k batches
   precision: 16 # 16 to use AMP
-  amp_level: O1 # O1 or O2 if using AMP
   accelerator: ddp
   gradient_clip_val: 1.0
   log_every_n_steps: 1

@@ -7,7 +7,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   precision: 16 # 16 to use AMP
-  amp_level: O1 # O1 or O2 if using AMP
   accelerator: ddp
   gradient_clip_val: 0.0
   log_every_n_steps: 1
examples/nlp/language_modeling/conf/megatron_gpt_config.yaml (new file, 117 lines)
@@ -0,0 +1,117 @@
name: megatron_gpt
restore_from_path: null # used when starting from a .nemo file

trainer:
  gpus: 1
  num_nodes: 1
  accelerator: ddp
  precision: 16
  logger: False # logger provided by exp_manager
  checkpoint_callback: False
  replace_sampler_ddp: False
  max_epochs: null
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 100
  limit_val_batches: 50
  limit_test_batches: 500
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: megatron_gpt
  create_wandb_logger: True
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 10
    mode: min
    always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
    filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}'


model:
  # model parallelism
  micro_batch_size: 4
  tensor_model_parallel_size: 1

  # model architecture
  encoder_seq_length: 512
  max_position_embeddings: 512
  num_layers: 12
  hidden_size: 768
  ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
  num_attention_heads: 12
  init_method_std: 0.02 # Standard deviation of the zero-mean normal distribution used for weight initialization.
  hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null.
  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
  layernorm_epsilon: 1e-5
  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
  pre_process: True # add embedding
  post_process: True # add pooler

  tokenizer:
    library: 'megatron'
    type: 'GPT2BPETokenizer'
    model: null
    vocab_file: null
    merge_file: null

  # precision
  native_amp_init_scale: 4294967296 # 2 ** 32
  native_amp_growth_interval: 1000
  fused_fp16: True # False if using fp32 or bf16
  fused_bf16: False # True if using bf16
  fp32_residual_connection: False # Move residual connections to fp32
  fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

  # miscellaneous
  seed: 1234
  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
  onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.

  # not implemented in NeMo yet
  activations_checkpoint_method: null # 'uniform', 'block'
  activations_checkpoint_num_layers: 1

  data:
    # Path to data must be specified by the user.
    # Can be overridden from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]"
    # Or see the example below:
    # data_prefix:
    #   - .5
    #   - /raid/data/pile/my-gpt3_00_text_document
    #   - .5
    #   - /raid/data/pile/my-gpt3_01_text_document
    data_prefix: ???
    data_impl: mmap
    splits_string: 900,50,50
    seq_length: 1024
    skip_warmup: True
    num_workers: 0
    dataloader_type: single # cyclic
    reset_position_ids: False # Reset position ids after end-of-document token
    reset_attention_mask: False # Reset attention mask after end-of-document token
    eod_mask_loss: False # Mask loss for the end of document tokens

  optim:
    name: adam
    lr: 2e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 500
      constant_steps: 50000
      min_lr: 2e-5
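The max_steps comment above ties optimizer steps to data consumption. A quick sanity check of that formula in plain Python (all numbers below are illustrative, not taken from this PR):

# Illustrative only: 2 nodes x 8 GPUs with tensor_model_parallel_size=2.
micro_batch_size = 4
accumulate_grad_batches = 1
world_size = 16
tensor_model_parallel_size = 2
data_parallel_size = world_size // tensor_model_parallel_size  # 8 data-parallel replicas
global_step = 100000

consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
print(consumed_samples)  # 3200000 samples consumed after 100k steps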

@@ -88,7 +88,6 @@ trainer:
   gpus: 4
   num_nodes: 1
   max_epochs: 200
-  amp_level: O2 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: False
examples/nlp/language_modeling/megatron_gpt_ckpt_to_nemo.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from argparse import ArgumentParser

import torch.multiprocessing as mp
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.utils import AppState, logging


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--checkpoint_folder",
        type=str,
        default=None,
        required=True,
        help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
    )
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default=None,
        required=True,
        help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
    )

    parser.add_argument(
        "--hparams_file",
        type=str,
        default=None,
        required=False,
        help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
    )
    parser.add_argument("--nemo_file_path", type=str, default=None, required=True, help="Path to output .nemo file.")

    parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None)

    args = parser.parse_args()
    return args


def convert(rank, world_size, args):

    app_state = AppState()
    app_state.data_parallel_rank = 0
    trainer = Trainer(gpus=args.tensor_model_parallel_size)
    # TODO: reach out to PTL for an API-safe local rank override
    trainer.accelerator.training_type_plugin._local_rank = rank
    checkpoint_path = os.path.join(args.checkpoint_folder, f'mp_rank_{rank:02d}', args.checkpoint_name)
    model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()
    model.save_to(args.nemo_file_path)
    logging.info(f'NeMo model saved to: {args.nemo_file_path}')


def main() -> None:
    args = get_args()
    world_size = args.tensor_model_parallel_size
    mp.spawn(convert, args=(world_size, args), nprocs=world_size, join=True)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
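For orientation, this is how each spawned rank resolves its checkpoint shard, mirroring the os.path.join call in convert() above (the folder and file names are hypothetical examples taken from the argparse help strings):

# Hypothetical layout for tensor_model_parallel_size=2.
import os

checkpoint_folder = "/raid/nemo_experiments/megatron_gpt/checkpoints"
checkpoint_name = "megatron_gpt--val_loss=6.34-step=649-last.ckpt"
for rank in range(2):
    print(os.path.join(checkpoint_folder, f"mp_rank_{rank:02d}", checkpoint_name))
# .../checkpoints/mp_rank_00/megatron_gpt--val_loss=6.34-step=649-last.ckpt
# .../checkpoints/mp_rank_01/megatron_gpt--val_loss=6.34-step=649-last.ckpt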
examples/nlp/language_modeling/megatron_gpt_eval.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from argparse import ArgumentParser

import torch
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader

from nemo.collections.nlp.data.language_modeling.megatron.gpt_request_dataset import GPTRequestDataset
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin
from nemo.utils import logging
from nemo.utils.app_state import AppState

assert torch.cuda.is_available()


def main():
    parser = ArgumentParser()
    parser.add_argument("--model_file", type=str, default="", required=True, help="Pass path to model's .nemo file")
    parser.add_argument(
        "--prompt", type=str, default="", required=True, help="Prompt for the model (a text to complete)"
    )
    parser.add_argument(
        "--tokens_to_generate", type=int, default=64, required=False, help="How many tokens to add to prompt"
    )
    parser.add_argument(
        "--stop_after_sentence",
        type=bool,
        default=True,
        required=False,
        help="True/False: whether to stop after full sentence has been generated.",
    )
    parser.add_argument(
        "--tensor_model_parallel_size", type=int, default=1, required=True,
    )
    parser.add_argument("--precision", default=32, help="PyTorch Lightning Trainer precision flag")

    args = parser.parse_args()

    # cast precision to int if 32 or 16
    if args.precision in ["32", "16"]:
        args.precision = int(float(args.precision))

    # trainer required for restoring model parallel models
    trainer = Trainer(plugins=NLPDDPPlugin(), gpus=args.tensor_model_parallel_size, precision=args.precision)

    app_state = AppState()
    if args.tensor_model_parallel_size is not None and args.tensor_model_parallel_size > 1:
        app_state.model_parallel_size = args.tensor_model_parallel_size
        app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)

    model = MegatronGPTModel.restore_from(restore_path=args.model_file, trainer=trainer)

    model.freeze()

    request = {
        "prompt": args.prompt,
        "tokens_to_generate": args.tokens_to_generate,
        "stop_after_sentence": args.stop_after_sentence,
    }

    dataset = GPTRequestDataset(request, model.tokenizer)

    request_dl = DataLoader(dataset)

    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response[0]['completion']['text'])
    print("***************************")
    logging.info(f"Generation stopped because: {response[0]['completion']['stop reason']}")


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
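A sketch of the request/response round trip the script performs; the prompt is made up, and the response keys are taken from the print statements above:

request = {
    "prompt": "Deep learning is",  # illustrative prompt
    "tokens_to_generate": 64,
    "stop_after_sentence": True,
}
# dataset = GPTRequestDataset(request, model.tokenizer)
# response = trainer.predict(model, DataLoader(dataset))
# response[0]['completion']['text']         -> the generated continuation
# response[0]['completion']['stop reason']  -> why generation stopped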
examples/nlp/language_modeling/megatron_gpt_pretraining.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from omegaconf.omegaconf import OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import (
    NLPCheckpointConnector,
    NLPDDPPlugin,
    NLPNativeBfloat16PrecisionPlugin,
    NLPNativeMixedPrecisionPlugin,
    NLPPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    if cfg.trainer.precision == 16:
        trainer = Trainer(
            plugins=[
                NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes),
                NLPNativeMixedPrecisionPlugin(
                    init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                    growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                ),
            ],
            **cfg.trainer,
        )
    elif cfg.trainer.precision == 'bf16':
        trainer = Trainer(
            plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPNativeBfloat16PrecisionPlugin(),],
            **cfg.trainer,
        )
    else:
        trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPPrecisionPlugin()], **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer.resume_from_checkpoint
    if resume_from_checkpoint is not None:
        mp_rank = compute_model_parallel_rank(trainer.local_rank, cfg.model.tensor_model_parallel_size)
        resume_from_checkpoint = Path(resume_from_checkpoint)
        resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(f'mp_rank_{mp_rank:02d}').joinpath(
            resume_from_checkpoint.name
        )
        resume_from_checkpoint = str(resume_from_checkpoint)
        logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

        trainer.checkpoint_connector = NLPCheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)


if __name__ == '__main__':
    main()
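A hypothetical walk-through of the mp_rank path rewrite above: the checkpoint found by exp_manager already sits in one rank's subfolder, and each process swaps in its own rank directory (the path below is invented for illustration):

from pathlib import Path

ckpt = Path('/exp/megatron_gpt/checkpoints/mp_rank_00/megatron_gpt--last.ckpt')
mp_rank = 1
print(ckpt.parent.parent.joinpath(f'mp_rank_{mp_rank:02d}').joinpath(ckpt.name))
# /exp/megatron_gpt/checkpoints/mp_rank_01/megatron_gpt--last.ckpt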
examples/nlp/language_modeling/megatron_gpt_test.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf.omegaconf import OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import (
    NLPDDPPlugin,
    NLPNativeBfloat16PrecisionPlugin,  # needed by the bf16 branch below; missing from the original import list
    NLPNativeMixedPrecisionPlugin,
    NLPPrecisionPlugin,
    NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    trainer = None
    if cfg.trainer.precision == 16:
        trainer = Trainer(
            plugins=[
                NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes),
                NLPNativeMixedPrecisionPlugin(
                    init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                    growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                ),
            ],
            **cfg.trainer,
        )
    elif cfg.trainer.precision == 'bf16':
        trainer = Trainer(
            plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPNativeBfloat16PrecisionPlugin(),],
            **cfg.trainer,
        )
    else:
        trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes), NLPPrecisionPlugin()], **cfg.trainer)

    app_state = AppState()
    app_state.model_parallel_size = cfg.model.tensor_model_parallel_size
    app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)

    model = MegatronGPTModel.restore_from(
        cfg.restore_from_path, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(),
    )

    # Note: most nemo models must have the data paths configured before instantiating the model.
    # MegatronGPTModel sets up the data in the PTL method .setup, which happens after DDP spawns.
    model.cfg.data.splits_string = cfg.model.data.splits_string

    trainer.test(model)


if __name__ == '__main__':
    main()

@@ -145,7 +145,6 @@ trainer:
   gpus: 4
   num_nodes: 1
   max_epochs: 200
-  amp_level: O2 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: False

@@ -172,7 +172,6 @@ trainer:
   gpus: 4
   num_nodes: 1
   max_epochs: 200
-  amp_level: O2 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: False

@@ -119,7 +119,6 @@ trainer:
   gpus: 4
   num_nodes: 1
   max_epochs: 200
-  amp_level: O2 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: False

@@ -147,7 +147,6 @@ trainer:
   gpus: 4
   num_nodes: 1
   max_epochs: 200
-  amp_level: O2 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: False
@@ -10,7 +10,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   precision: 16 # 16 to use AMP
-  amp_level: O1 # O1 or O2 if using AMP
   accelerator: ddp
   gradient_clip_val: 0.0
   val_check_interval: 1.0  # check once per epoch; .25 for 4 times per epoch

@@ -21,7 +21,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
   accelerator: ddp
   log_every_n_steps: 1  # Interval of logging.

@@ -24,7 +24,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: false # Provided by exp_manager

@@ -23,7 +23,6 @@ trainer:
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
   gradient_clip_val: 0.0
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
   accelerator: ddp
   checkpoint_callback: False # Provided by exp_manager
@@ -62,13 +62,7 @@ def main(cfg: DictConfig) -> None:
     else:
         gpu = 1 if cfg.trainer.gpus != 0 else 0

-    trainer = pl.Trainer(
-        gpus=gpu,
-        precision=cfg.trainer.precision,
-        amp_level=cfg.trainer.amp_level,
-        logger=False,
-        checkpoint_callback=False,
-    )
+    trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision, logger=False, checkpoint_callback=False,)
     exp_dir = exp_manager(trainer, cfg.exp_manager)

     if not cfg.pretrained_model:
@@ -70,13 +70,7 @@ def main(cfg: DictConfig) -> None:
     else:
         gpu = 1 if cfg.trainer.gpus != 0 else 0

-    trainer = pl.Trainer(
-        gpus=gpu,
-        precision=cfg.trainer.precision,
-        amp_level=cfg.trainer.amp_level,
-        logger=False,
-        checkpoint_callback=False,
-    )
+    trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision, logger=False, checkpoint_callback=False,)
     exp_dir = exp_manager(trainer, cfg.exp_manager)

     if not cfg.pretrained_model:
@@ -19,7 +19,6 @@ trainer:
   max_epochs: 1
   max_steps: null # precedence over max_epochs
   accumulate_grad_batches: 1 # accumulates grads every k batches
-  amp_level: O0 # O1/O2 for mixed precision
   precision: 16
   accelerator: ddp
   log_every_n_steps: 1  # Interval of logging.
@@ -143,7 +143,6 @@ trainer:
   num_nodes: 1
   accelerator: ddp
   accumulate_grad_batches: 1
-  amp_level: O0
   deterministic: True
   checkpoint_callback: False
   logger: False

@@ -131,7 +131,6 @@ trainer:
   num_nodes: 1
   accelerator: ddp
   accumulate_grad_batches: 1
-  amp_level: O0
   deterministic: True
   checkpoint_callback: False
   logger: False

@@ -93,7 +93,6 @@ trainer:
   num_nodes: 1
   accelerator: ddp
   accumulate_grad_batches: 1
-  amp_level: O0
   deterministic: False
   checkpoint_callback: False
   logger: False
@@ -95,7 +95,6 @@ trainer:
   log_every_n_steps: 200
   check_val_every_n_epoch: 25
   precision: 16
-  amp_level: O1

 exp_manager:
   exp_dir: null
@@ -29,7 +29,7 @@ class LogEpochTimeCallback(Callback):
         self.epoch_start = time.time()

     @rank_zero_only
-    def on_train_epoch_end(self, trainer, pl_module, outputs):
+    def on_train_epoch_end(self, trainer, pl_module):
         curr_time = time.time()
         duration = curr_time - self.epoch_start
         trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)
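The dropped outputs argument tracks an upstream API change: newer PyTorch Lightning releases call on_train_epoch_end with only (trainer, pl_module). A minimal self-contained sketch of the updated callback; the epoch-start hook is inferred from the self.epoch_start context line above, not shown in this hunk:

import time

from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only


class LogEpochTimeCallback(Callback):
    """Logs wall-clock duration of each training epoch (sketch)."""

    @rank_zero_only
    def on_train_epoch_start(self, trainer, pl_module):
        self.epoch_start = time.time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        duration = time.time() - self.epoch_start
        trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)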
nemo/collections/common/parts/ptl_overrides.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin


class NeMoNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
    def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None:
        super().__init__(precision=16)

        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval)
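Hooking the plugin into a trainer might look like the sketch below; the trainer arguments other than plugins are placeholders, and passing a precision plugin through plugins is assumed to match the PTL 1.x convention used elsewhere in this PR:

from pytorch_lightning import Trainer

# A larger init_scale delays the first loss-scale backoff; growth_interval
# controls how often the GradScaler tries to raise the scale again.
plugin = NeMoNativeMixedPrecisionPlugin(init_scale=2 ** 32, growth_interval=1000)
trainer = Trainer(gpus=1, max_steps=100, plugins=[plugin])  # illustrative args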

@@ -33,6 +33,7 @@ class AutoTokenizer(TokenizerSpec):
         self,
         pretrained_model_name: str,
         vocab_file: Optional[str] = None,
+        merges_file: Optional[str] = None,
         mask_token: Optional[str] = None,
         bos_token: Optional[str] = None,
         eos_token: Optional[str] = None,
@@ -60,19 +61,26 @@ class AutoTokenizer(TokenizerSpec):
             use_fast: whether to use fast HuggingFace tokenizer
         """
         try:
-            if vocab_file is not None:
-                message = 'Using "slow" HuggingFace tokenizer'
-                if use_fast:
-                    message += f'{vocab_file} is ignored in "fast" tokenizers, using a "slow" version'
+            # this logic deals with different huggingface tokenizers having different positional args
+            if vocab_file is None:
                 self.tokenizer = AUTOTOKENIZER.from_pretrained(
-                    pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=False
+                    pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast,
+                )
+            elif merges_file is None:
+                self.tokenizer = AUTOTOKENIZER.from_pretrained(
+                    pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=use_fast,
                 )
             else:
                 self.tokenizer = AUTOTOKENIZER.from_pretrained(
-                    pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast
+                    pretrained_model_name_or_path=pretrained_model_name,
+                    vocab_file=vocab_file,
+                    merges_file=merges_file,
+                    use_fast=use_fast,
                 )
         except Exception as e:
-            raise ValueError(f'{pretrained_model_name} is not supported by HuggingFace. {e}')
+            raise ValueError(
+                f'Unable to instantiate HuggingFace AUTOTOKENIZER for {pretrained_model_name}. Exception: {e}'
+            )

         special_tokens_dict = {}
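The three from_pretrained branches above correspond to three common ways of constructing a tokenizer. Roughly, assuming the usual NeMo import path (model names and file paths below are illustrative):

# 1) name only, e.g. everything resolved from the HuggingFace hub
tok = AutoTokenizer(pretrained_model_name='bert-base-uncased')

# 2) name plus a local vocab file (WordPiece-style tokenizers)
tok = AutoTokenizer(pretrained_model_name='bert-base-uncased', vocab_file='/path/to/vocab.txt')

# 3) name plus vocab and merges files (BPE tokenizers such as GPT-2's)
tok = AutoTokenizer(
    pretrained_model_name='gpt2', vocab_file='/path/to/vocab.json', merges_file='/path/to/merges.txt'
)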

@@ -17,4 +17,5 @@ from nemo.collections.nlp.data.language_modeling.lm_bert_dataset import (
     BertPretrainingDataset,
     BertPretrainingPreprocessedDataloader,
 )
+from nemo.collections.nlp.data.language_modeling.megatron import GPTDataset, IndexedDataset, MMapIndexedDataset
 from nemo.collections.nlp.data.language_modeling.sentence_dataset import SentenceDataset, TarredSentenceDataset
@@ -0,0 +1,23 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)

default: $(LIBNAME)$(LIBEXT)

%$(LIBEXT): %.cpp
	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
@@ -0,0 +1,20 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import IndexedDataset, MMapIndexedDataset

# TODO: refactor these datasets to work without megatron-lm dependency
# from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset
# from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset
@@ -0,0 +1,216 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Style dataset."""

import numpy as np
import torch

from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
    create_masked_lm_predictions,
    create_tokens_and_tokentypes,
    get_a_and_b_segments,
    get_samples_mapping,
    truncate_segments,
)


class BertDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        name,
        indexed_dataset,
        data_prefix,
        num_epochs,
        max_num_samples,
        masked_lm_prob,
        max_seq_length,
        short_seq_prob,
        seed,
        binary_head,
        tokenizer,
    ):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(
            self.indexed_dataset,
            data_prefix,
            num_epochs,
            max_num_samples,
            self.max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            self.seed,
            self.name,
            self.binary_head,
        )

        # Vocab stuff.
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls_id
        self.sep_id = tokenizer.sep_id
        self.mask_id = tokenizer.mask_id
        self.pad_id = tokenizer.pad_id

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2 ** 32))
        return build_training_sample(
            sample,
            seq_length,
            self.max_seq_length,  # needed for padding
            self.vocab_id_list,
            self.vocab_id_to_token_dict,
            self.cls_id,
            self.sep_id,
            self.mask_id,
            self.pad_id,
            self.masked_lm_prob,
            np_rng,
            self.binary_head,
        )


def build_training_sample(
    sample,
    target_seq_length,
    max_seq_length,
    vocab_id_list,
    vocab_id_to_token_dict,
    cls_id,
    sep_id,
    mask_id,
    pad_id,
    masked_lm_prob,
    np_rng,
    binary_head,
):
    """Build a training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
    """

    if binary_head:
        # We assume that we have at least two sentences in the sample
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    if binary_head:
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
    else:
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng)

    # Build tokens and tokentypes.
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id)

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens,
        vocab_id_list,
        vocab_id_to_token_dict,
        masked_lm_prob,
        cls_id,
        sep_id,
        mask_id,
        max_predictions_per_seq,
        np_rng,
    )

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
        tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
    )

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated),
    }
    return train_sample


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
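A mini-example of the padding helper above; the token ids are made up for illustration:

tokens = [101, 7592, 2088, 102]  # 4 token ids
tokentypes = [0, 0, 0, 0]
masked_positions, masked_labels = [1], [7592]
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
    tokens, tokentypes, masked_positions, masked_labels, pad_id=0, max_seq_length=8
)
# tokens_np       -> [101 7592 2088 102 0 0 0 0]
# padding_mask_np -> [1 1 1 1 0 0 0 0]
# labels_np       -> [-1 7592 -1 -1 -1 -1 -1 -1]; loss_mask_np flags only position 1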

@@ -0,0 +1,75 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time

import numpy as np
import torch

from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero


class BlendableDataset(torch.utils.data.Dataset):
    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        try:
            if is_global_rank_zero():
                from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper

                compile_helper()
            from nemo.collections.nlp.data.language_modeling.megatron import helpers
        except Exception:
            raise Exception('Could not compile helpers.')
        helpers.build_blending_indices(
            self.dataset_index,
            self.dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            torch.distributed.get_rank() == 0,
        )
        logging.info(
            '> elapsed time for building blendable dataset indices: {:.2f} (sec)'.format(time.time() - start_time)
        )

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
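BlendableDataset normalizes whatever weights it receives, so only their ratios matter. A toy illustration of that normalization step (the datasets themselves are stand-ins):

import numpy as np

weights = np.array([3.0, 1.0], dtype=np.float64)
weights /= np.sum(weights)
print(weights)  # [0.75 0.25] -> dataset 0 supplies ~75% of the blended samples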
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataloaders."""
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from nemo.utils import logging
|
||||
|
||||
|
||||
class MegatronPretrainingSampler:
|
||||
def __init__(
|
||||
self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
|
||||
):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
|
||||
self.drop_last = drop_last
|
||||
|
||||
logging.info(
|
||||
f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} and consumed_samples: {consumed_samples}'
|
||||
)
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
|
||||
assert self.consumed_samples < self.total_samples, 'no samples left to consume: {}, {}'.format(
|
||||
self.consumed_samples, self.total_samples
|
||||
)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert self.data_parallel_rank < data_parallel_size, (
|
||||
'data_parallel_rank should be smaller than data size: {}, '
|
||||
'{}'.format(self.data_parallel_rank, data_parallel_size)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def get_start_end_idx(self):
|
||||
start_idx = self.data_parallel_rank * self.micro_batch_size
|
||||
end_idx = start_idx + self.micro_batch_size
|
||||
return start_idx, end_idx
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
# Last batch will be dropped if drop_last is not set False
|
||||
for idx in range(self.consumed_samples, self.total_samples):
|
||||
batch.append(idx)
|
||||
if len(batch) == self.micro_batch_times_data_parallel_size:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]
|
||||
batch = []
|
||||
|
||||
# Check the last partial batch and see drop_last is set
|
||||
if len(batch) > 0 and not self.drop_last:
|
||||
start_idx, end_idx = self.get_start_end_idx()
|
||||
yield batch[start_idx:end_idx]


class MegatronPretrainingRandomSampler:
    def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, (
            'data_parallel_rank should be smaller than data parallel size: {}, '
            '{}'.format(self.data_parallel_rank, data_parallel_size)
        )

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling: each rank shuffles its own bucket
        # of indices, seeded by the epoch so all ranks stay in sync.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch is dropped if it is not complete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
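    # Worked example (illustrative, hypothetical values): with total_samples=10,
    # micro_batch_size=2 and data_parallel_size=2, last_batch_size = 10 % 4 = 2
    # and bucket_size = (10 // 4) * 2 = 4, so rank 0 shuffles indices [0, 4) and
    # rank 1 shuffles indices [4, 8); the two leftover samples are dropped, which
    # keeps every rank's notion of an epoch aligned.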

@@ -0,0 +1,732 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import collections
import math
import os
import subprocess
import time

import numpy as np
import torch
from apex.transformer import parallel_state

from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights.
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so that if the blending dataset does
    # not distribute the number of samples perfectly uniformly, we still
    # have samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
        )

    return prefixes, weights, datasets_train_valid_test_num_samples
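    # Worked example (illustrative, hypothetical values): for
    # data_prefix = ['0.3', 'web_docs', '0.7', 'books'] and
    # train_valid_test_num_samples = [1000, 100, 10], the weights normalize to
    # [0.3, 0.7] and the per-dataset sample counts come out as
    # [[302, 31, 4], [704, 71, 8]] after the 1.005 over-sampling factor.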


def compile_helper():
    """Compile helper functions at runtime. Make sure this
    is invoked on a single process."""

    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        logging.error("Making C++ dataset helpers module failed, exiting.")
        import sys

        sys.exit(1)


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A` segment.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive of the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next: with probability 0.5, swap the two segments and mark the
    # pair as a "random next" (negative) example for the NSP objective.
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        # Trim the longer of the two segments, removing a token at random
        # from either the front or the back.
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True
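    # Worked example (illustrative, hypothetical values): with len_a=5, len_b=3
    # and max_num_tokens=6, the loop removes two tokens from tokens_a (always
    # trimming the longer segment), ending at len_a=3, len_b=3 and returning
    # True. Note that tokens_a/tokens_b are mutated in place.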


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes
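    # Worked example (illustrative, hypothetical values): for tokens_a=[11, 12],
    # tokens_b=[21], cls_id=101 and sep_id=102, the result is
    #   tokens     = [101, 11, 12, 102, 21, 102]
    #   tokentypes = [  0,  0,  0,   0,  1,   1]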


MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into WordPieces, the first token does not
    # have any marker and any subsequent tokens are prefixed with ##. So
    # whenever we see the ## prefix, we append the piece to the previous
    # set of word indexes.
    return not piece.startswith("##")


def create_masked_lm_predictions(
    tokens,
    vocab_id_list,
    vocab_id_to_token_dict,
    masked_lm_prob,
    cls_id,
    sep_id,
    mask_id,
    max_predictions_per_seq,
    np_rng,
    max_ngrams=3,
    do_whole_word_mask=True,
    favor_longer_ngram=False,
    do_permutation=False,
    geometric_dist=False,
    masking_style="bert",
):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of the current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        # Return an empty span list as well, so the arity matches the
        # return statement at the end of this function.
        return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, [])

    num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probabilities to favor shorter ngram sequences.
        pvals = 1.0 / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]
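    # Worked example (illustrative): with max_ngrams=3, pvals is
    # [1, 1/2, 1/3] normalized to [6/11, 3/11, 2/11], so a unigram span is three
    # times as likely as a trigram span unless favor_longer_ngram flips the order.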

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx : idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip the current piece if it is covered by lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(
                ngrams[: len(cand_index_set)],
                p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
            )
        else:
            # Sample "n" from the geometric distribution and clip it to
            # max_ngrams, using the p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1).
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly look for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK].
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep the original token.
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with a random word.
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(index=index_set, label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip the current piece if it is covered by lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            # Use the seeded generator here; the copied code originally called
            # the global np.random.choice, which ignores the per-sample rng.
            n = np_rng.choice(
                ngrams[: len(cand_index_set)],
                p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
            )
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first token in the span.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
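    # Worked example (illustrative, hypothetical values): padding tokens=[5, 6, 7]
    # with tokentypes=[0, 0, 0], masked_positions=[1], masked_labels=[9], pad_id=0
    # and max_seq_length=5 gives tokens_np=[5, 6, 7, 0, 0],
    # padding_mask_np=[1, 1, 1, 0, 0], labels_np=[-1, 9, -1, -1, -1] and
    # loss_mask_np=[0, 1, 0, 0, 0].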


def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head=False,
    max_seq_length_dec=None,
    dataset_type='standard_bert',
):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            max_seq_length_dec,
            dataset_type=dataset_type,
        )
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            # Pass max_seq_length_dec through; the original call omitted it,
            # which would raise a TypeError on the blending path.
            max_seq_length_dec,
            dataset_type=dataset_type,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)


def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    max_seq_length_dec,
    dataset_type='standard_bert',
):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    if dataset_type == DSET_TYPE_ICT:
        # The ICT path relied on Megatron-LM's global `get_args()` (for
        # `titles_data_path`), which is no longer a dependency here.
        raise NotImplementedError("ICT dataset is not supported yet.")

    # Get start and end indices of train/valid/test into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    logging.info(' > dataset split:')

    def print_split_stats(name, index):
        logging.info('    {}:'.format(name))
        logging.info(
            '     document indices in [{}, {}) total of {} '
            'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
        )
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logging.info(
            '     sentence indices in [{}, {}) total of {} '
            'sentences'.format(start_index, end_index, end_index - start_index)
        )

    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset
        from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset

        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type == DSET_TYPE_T5:
                dataset = T5Dataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    max_seq_length_dec=max_seq_length_dec,
                    short_seq_prob=short_seq_prob,
                    **kwargs,
                )
            elif dataset_type == DSET_TYPE_BERT:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    **kwargs,
                )
            else:
                raise NotImplementedError("Dataset type not fully implemented.")

            # Restore the original pointer so the full dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    logging.info(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    logging.info(' > finished creating indexed dataset in {:4f} seconds'.format(time.time() - start_time))

    logging.info(' > indexed dataset stats:')
    logging.info('    number of documents: {}'.format(indexed_dataset.doc_idx.shape[0] - 1))
    logging.info('    number of sentences: {}'.format(indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.0)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] + int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
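    # Worked example (illustrative): splits_string='900,50,50' with size=1000
    # normalizes to [0.9, 0.05, 0.05] and returns splits_index=[0, 900, 950, 1000],
    # i.e. documents [0, 900) for train, [900, 950) for validation and
    # [950, 1000) for test.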


def get_samples_mapping(
    indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
):
    """Get a list that maps a sample index to a starting sentence index, end sentence index, and length."""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping.
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if it doesn't exist.
    if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
        print(
            ' > WARNING: could not find index map file {}, building '
            'the indices on rank 0 ...'.format(indexmap_filename)
        )

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32

        # Build the samples mapping.
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        logging.info(' > building samples index mapping for {} ...'.format(name))
        # First compile and then import.
        try:
            if is_global_rank_zero():
                compile_helper()
            from nemo.collections.nlp.data.language_modeling.megatron import helpers
        except:
            raise Exception('Could not compile helpers.')
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length,
            short_seq_prob,
            seed,
            verbose,
            2 if binary_head else 1,
        )
        logging.info(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        logging.info(' > saved the index mapping in {}'.format(indexmap_filename))
        logging.info(
            ' > elapsed time to build and save samples mapping '
            '(seconds): {:4f}'.format(time.time() - start_time)
        )
    # Make sure all the ranks have built the mapping.
    # This should be a barrier, but the nccl barrier assumes
    # device_index=rank, which is not the case for the model
    # parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size()
        // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
    )

    # Load indexed dataset.
    logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    logging.info('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
    logging.info('    total number of samples: {}'.format(samples_mapping.shape[0]))

    return samples_mapping
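    # Worked example (illustrative, hypothetical values): for
    # data_prefix='my_corpus', name='train', num_epochs=3, max_num_samples left
    # unset, max_seq_length=512, short_seq_prob=0.1 and seed=1234, the cached
    # mapping is written to 'my_corpus_train_indexmap_3ep_512msl_0.10ssp_1234s.npy'.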

@@ -0,0 +1,449 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT style dataset."""

import os
import time

import numpy as np
import torch
from apex.transformer import parallel_state

from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
    get_datasets_weights_and_num_samples,
    get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_dataset import MegatronDataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero


def build_train_valid_test_datasets(
    cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            cfg,
            trainer,
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            seq_length,
            seed,
            skip_warmup,
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            cfg,
            trainer,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length,
            seed,
            skip_warmup,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)


def _build_train_valid_test_datasets(
    cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    logging.info(' > dataset split:')

    def print_split_stats(name, index):
        logging.info('    {}:'.format(name))
        logging.info(
            '     document indices in [{}, {}) total of {} '
            'documents'.format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
        )

    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
            dataset = GPTDataset(
                cfg,
                trainer,
                name,
                data_prefix,
                documents,
                indexed_dataset,
                train_valid_test_num_samples[index],
                seq_length,
                seed,
            )
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build indexed dataset."""
    logging.info(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    logging.info(' > finished creating indexed dataset in {:4f} seconds'.format(time.time() - start_time))
    logging.info('    number of documents: {}'.format(indexed_dataset.sizes.shape[0]))

    return indexed_dataset


class GPTDataset(MegatronDataset):
    def __init__(self, cfg, trainer, name, data_prefix, documents, indexed_dataset, num_samples, seq_length, seed):

        super().__init__(cfg, trainer=trainer)
        self.name = name
        self.indexed_dataset = indexed_dataset

        # Checks
        assert np.min(documents) >= 0
        assert np.max(documents) < indexed_dataset.sizes.shape[0]

        # Build index mappings.
        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
            self.name, data_prefix, documents, self.indexed_dataset.sizes, num_samples, seq_length, seed
        )

    def __len__(self):
        # -1 is due to the data structure used to retrieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1

    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx + 1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx + 1][1]
        # If we are within the same document, just extract the chunk.
        if doc_index_f == doc_index_l:
            sample = self.indexed_dataset.get(
                self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1
            )
        else:
            # Otherwise, get the rest of the initial document.
            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
            # Loop over all in-between documents and add each entire document.
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of the last document.
            sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1))
            sample = np.concatenate(sample_list)

        return {'text': np.array(sample, dtype=np.int64)}
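        # Worked example (illustrative): if sample_idx[i] = (1, 0) and
        # sample_idx[i + 1] = (2, 1), the sample starts in the document at
        # doc_idx position 1 with offset 0 and ends in the document at position 2
        # at offset 1, so the else-branch concatenates the tail of one document
        # with the head of the next to produce seq_length + 1 tokens.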


def _build_index_mappings(name, data_prefix, documents, sizes, num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.
    doc-idx: is an array (ordered) of documents to be used in training.
    sample-idx: is the start document index and document offset for each
       training sample.
    shuffle-idx: maps the sample index into a random index into sample-idx.
    """
    # Number of tokens in each epoch and number of required epochs.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    # RNG state.
    np_rng = np.random.RandomState(seed=seed)

    # Filename of the index mappings.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}sl'.format(seq_length)
    _filename += '_{}s'.format(seed)
    doc_idx_filename = _filename + '_doc_idx.npy'
    sample_idx_filename = _filename + '_sample_idx.npy'
    shuffle_idx_filename = _filename + '_shuffle_idx.npy'

    # Build the indexed mapping if it doesn't exist.
    if torch.distributed.get_rank() == 0:
        if (
            (not os.path.isfile(doc_idx_filename))
            or (not os.path.isfile(sample_idx_filename))
            or (not os.path.isfile(shuffle_idx_filename))
        ):

            logging.info(' > WARNING: could not find index map files, building ' 'the indices on rank 0 ...')

            # For the last epoch, decide whether to include the entire epoch
            # in the global shuffle or not.

            # If we need only one epoch, then separating out the last epoch
            # does not mean anything.
            if num_epochs == 1:
                separate_last_epoch = False
                print(' > only one epoch required, setting ' 'separate_last_epoch to False', flush=True)

            else:
                # Get the number of samples for the last epoch.
                num_samples_from_epochs_minus_one = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length
                last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one
                assert last_epoch_num_samples >= 0, 'last epoch number of samples should be non-negative.'
                num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
                assert last_epoch_num_samples < (
                    num_samples_per_epoch + 1
                ), 'last epoch number of samples exceeded max value.'
                # If we have less than 80% of the samples for the last epoch,
                # separate out the epoch and treat it differently.
                # Note: the 80% number is just based on common sense and can
                # be adjusted if needed.
                separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch)
                if separate_last_epoch:
                    string = (
                        ' > last epoch number of samples ({}) is smaller '
                        'than 80% of number of samples per epoch ({}), '
                        'setting separate_last_epoch to True'
                    )
                else:
                    string = (
                        ' > last epoch number of samples ({}) is larger '
                        'than 80% of number of samples per epoch ({}), '
                        'setting separate_last_epoch to False'
                    )
                print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True)

            # doc-idx.
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            logging.info(
                ' > elapsed time to build and save doc-idx mapping '
                '(seconds): {:4f}'.format(time.time() - start_time)
            )
            # sample-idx.
            start_time = time.time()
            # Use the C++ implementation for speed.
            # First compile and then import.
            assert doc_idx.dtype == np.int32
            assert sizes.dtype == np.int32
            try:
                if is_global_rank_zero():
                    from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper

                    compile_helper()
                from nemo.collections.nlp.data.language_modeling.megatron import helpers
            except:
                raise Exception('Could not compile helpers.')

            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch)
            # Equivalent (slower) Python implementation:
            # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch)
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            logging.info(
                ' > elapsed time to build and save sample-idx mapping '
                '(seconds): {:4f}'.format(time.time() - start_time)
            )
            # shuffle-idx.
            start_time = time.time()
            # -1 is due to the data structure used to retrieve the index:
            #    sample i --> [sample_idx[i], sample_idx[i+1])
            if separate_last_epoch:
                num_samples_ = num_samples_from_epochs_minus_one
            else:
                num_samples_ = sample_idx.shape[0] - 1
            shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            logging.info(
                ' > elapsed time to build and save shuffle-idx mapping'
                ' (seconds): {:4f}'.format(time.time() - start_time)
            )

    torch.distributed.barrier()

    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size()
        // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
    )

    # Load mappings.
    start_time = time.time()
    logging.info(' > loading doc-idx mapping from {}'.format(doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
    logging.info(' > loading sample-idx mapping from {}'.format(sample_idx_filename))
    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
    logging.info(' > loading shuffle-idx mapping from {}'.format(shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
    logging.info('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
    logging.info('    total number of samples: {}'.format(sample_idx.shape[0]))
    logging.info('    total number of epochs: {}'.format(num_epochs))

    return doc_idx, sample_idx, shuffle_idx


def _num_tokens(documents, sizes):
    """Total number of tokens in the dataset."""
    return np.sum(sizes[documents])


def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Based on the number of samples and the sequence length, calculate
    how many epochs will be needed."""
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        # -1 is because we need to retrieve seq_length + 1 tokens each time
        # but the last token will overlap with the first token of the next
        # sample except for the last sample.
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return num_epochs
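    # Worked example (illustrative, hypothetical values): with
    # tokens_per_epoch=1000, seq_length=10 and num_samples=250, two epochs yield
    # (2000 - 1) // 10 = 199 samples, which is not enough, while three epochs
    # yield (3000 - 1) // 10 = 299 >= 250, so the function returns 3.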


def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
    """Build an array with length = number-of-epochs * number-of-documents.
    Each index is mapped to a corresponding document."""
    if not separate_last_epoch or num_epochs == 1:
        doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
        doc_idx[:] = documents
        doc_idx = doc_idx.reshape(-1)
        doc_idx = doc_idx.astype(np.int32)
        np_rng.shuffle(doc_idx)
        return doc_idx

    doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
    doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
    return np.concatenate((doc_idx_first, doc_idx_last))
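    # Worked example (illustrative, hypothetical values): for documents=[7, 8, 9]
    # and num_epochs=2 with separate_last_epoch=False, doc_idx starts as
    # [7, 8, 9, 7, 8, 9] and is shuffled in place. With separate_last_epoch=True,
    # the first num_epochs - 1 copies and the final copy are shuffled
    # independently, so indices of the last epoch stay contiguous at the end.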


def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
    """Sample index mapping is a 2D array with sizes
    [number-of-samples + 1, 2] where [..., 0] contains
    the index into `doc_idx` and [..., 1] is the
    starting offset in that document."""

    # Total number of samples. For -1 see comments in `_num_epochs`.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)

    # Index into sample_idx.
    sample_index = 0
    # Index into doc_idx.
    doc_idx_index = 0
    # Beginning offset for each document.
    doc_offset = 0
    # Start with the first document and no offset.
    sample_idx[sample_index][0] = doc_idx_index
    sample_idx[sample_index][1] = doc_offset
    sample_index += 1
    while sample_index <= num_samples:
        # Start with a fresh sequence.
        remaining_seq_length = seq_length + 1
        while remaining_seq_length != 0:
            # Get the document length.
            doc_id = doc_idx[doc_idx_index]
            doc_length = sizes[doc_id] - doc_offset
            # And add it to the current sequence.
            remaining_seq_length -= doc_length
            # If we have more than a full sequence, adjust the offset and set
            # the remaining length to zero so we return from the while loop.
            # Note that -1 here is for the same reason we have -1 in
            # `_num_epochs` calculations.
            if remaining_seq_length <= 0:
                doc_offset += remaining_seq_length + doc_length - 1
                remaining_seq_length = 0
            else:
                # Otherwise, start from the beginning of the next document.
                doc_idx_index += 1
                doc_offset = 0
        # Record the sequence.
        sample_idx[sample_index][0] = doc_idx_index
        sample_idx[sample_index][1] = doc_offset
        sample_index += 1

    return sample_idx
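    # Worked example (illustrative, hypothetical values): with sizes=[4, 3],
    # doc_idx=[0, 1, 0, 1] (identity order for readability), seq_length=4 and
    # num_epochs=2 (tokens_per_epoch=7, hence (2 * 7 - 1) // 4 = 3 samples), the
    # walk above produces
    #   sample_idx = [[0, 0], [1, 0], [2, 1], [3, 1]]
    # e.g. sample 1 starts at doc_idx position 1, offset 0 and ends at position 2,
    # offset 1, consuming seq_length + 1 = 5 tokens with one token of overlap.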


def _build_shuffle_idx(num_samples, total_size, np_rng):
    """Build the range [0, size) and shuffle."""
    print(
        ' > building shuffle index with split [0, {}) and [{}, {}) '
        '...'.format(num_samples, num_samples, total_size),
        flush=True,
    )

    dtype_ = np.uint32
    if total_size >= (np.iinfo(np.uint32).max - 1):
        dtype_ = np.int64

    shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx_first)
    if num_samples == total_size:
        return shuffle_idx_first

    shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx_last)

    return np.concatenate((shuffle_idx_first, shuffle_idx_last))
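    # Worked example (illustrative, hypothetical values): with num_samples=3 and
    # total_size=5, the ranges [0, 3) and [3, 5) are shuffled independently,
    # e.g. [2, 0, 1] + [4, 3], so samples from a separated last epoch never
    # migrate into earlier epochs.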

@@ -0,0 +1,37 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

import torch
from torch.utils.data.dataset import Dataset


class GPTRequestDataset(Dataset):
    def __init__(self, request: Dict, tokenizer) -> None:
        super().__init__()
        self.request = request
        self.tokenizer = tokenizer

        # Tokenize the prompt: keep both the token strings and the id tensor.
        self.request['tokenized_prompt'] = self.tokenizer.text_to_tokens(request['prompt'])
        tokens = self.tokenizer.text_to_ids(request['prompt'])
        self.request['tokens'] = torch.tensor(tokens)

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return self.request
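    # Usage sketch (illustrative; `tokenizer` stands in for any NeMo tokenizer
    # exposing text_to_tokens/text_to_ids):
    #   request = {'prompt': 'Deep learning is'}
    #   dataset = GPTRequestDataset(request, tokenizer)
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    # The single item carries the raw prompt, its token strings and a tensor of
    # token ids for the model's complete/predict step.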

nemo/collections/nlp/data/language_modeling/megatron/helpers.cpp (new file, 716 lines)
@@ -0,0 +1,716 @@
/*
 Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/

/* Helper methods for fast index mapping builds */

#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>

namespace py = pybind11;
using namespace std;

const int32_t LONG_SENTENCE_LEN = 512;


void build_blending_indices(py::array_t<uint8_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose) {
    /* Given multiple datasets and a weighting array, build samples
       such that they follow those weights. */

    if (verbose) {
        std::cout << "> building indices for blendable datasets ..." << std::endl;
    }

    // Get the pointer access without the checks.
    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
    auto weights_ptr = weights.unchecked<1>();

    // Initialize buffer for number of samples used for each dataset.
    int64_t current_samples[num_datasets];
    for (int64_t i = 0; i < num_datasets; ++i) {
        current_samples[i] = 0;
    }

    // For each sample:
    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

        // Determine where the max error in sampling is happening.
        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
        int64_t max_error_index = 0;
        double max_error = weights_ptr[0] * sample_idx_double -
                           static_cast<double>(current_samples[0]);
        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
            double error = weights_ptr[dataset_idx] * sample_idx_double -
                           static_cast<double>(current_samples[dataset_idx]);
            if (error > max_error) {
                max_error = error;
                max_error_index = dataset_idx;
            }
        }

        // Populate the indices.
        dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

        // Update the total samples.
        current_samples[max_error_index] += 1;

    }

    // Print info.
    if (verbose) {
        std::cout << " > sample ratios:" << std::endl;
        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
            auto ratio = static_cast<double>(current_samples[dataset_idx]) /
                         static_cast<double>(size);
            std::cout << "   dataset " << dataset_idx << ", input: " <<
                weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
        }
    }

}
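// Worked example (illustrative, hypothetical values): with two datasets and
// weights {0.75, 0.25}, the greedy rule above always picks the dataset whose
// achieved sample count lags its target share the most, producing the
// interleaving 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, ... so every prefix of the
// blended stream tracks the 3:1 target ratio as closely as possible.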


py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch) {
    /* Sample index (sample_idx) is used for gpt2-like datasets, for which
       the documents are flattened and the samples are built based on this
       1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
       where [..., 0] contains the index into `doc_idx` and [..., 1] is the
       starting offset in that document. */

    // Consistency checks.
    assert(seq_length > 1);
    assert(num_epochs > 0);
    assert(tokens_per_epoch > 1);

    // Remove bound checks.
    auto sizes = sizes_.unchecked<1>();
    auto doc_idx = doc_idx_.unchecked<1>();

    // Mapping and its length (1D).
    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
    int32_t* sample_idx = new int32_t[2 * (num_samples + 1)];

    cout << "    using:" << endl << std::flush;
    cout << "     number of documents:     " << doc_idx_.shape(0) / num_epochs << endl << std::flush;
    cout << "     number of epochs:        " << num_epochs << endl << std::flush;
    cout << "     sequence length:         " << seq_length << endl << std::flush;
    cout << "     total number of samples: " << num_samples << endl << std::flush;

    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Beginning offset for each document.
    int32_t doc_offset = 0;
    // Start with the first document and no offset.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;

    while (sample_index <= num_samples) {
        // Start with a fresh sequence.
        int32_t remaining_seq_length = seq_length + 1;
        while (remaining_seq_length != 0) {
            // Get the document length.
            auto doc_id = doc_idx[doc_idx_index];
            auto doc_length = sizes[doc_id] - doc_offset;
            // And add it to the current sequence.
            remaining_seq_length -= doc_length;
            // If we have more than a full sequence, adjust the offset and set
            // the remaining length to zero so we return from the while loop.
            // Note that -1 here is for the same reason we have -1 in
            // `_num_epochs` calculations.
            if (remaining_seq_length <= 0) {
                doc_offset += (remaining_seq_length + doc_length - 1);
                remaining_seq_length = 0;
            } else {
                // Otherwise, start from the beginning of the next document.
                ++doc_idx_index;
                doc_offset = 0;
            }
        }
        // Record the sequence.
        sample_idx[2 * sample_index] = doc_idx_index;
        sample_idx[2 * sample_index + 1] = doc_offset;
        ++sample_index;
    }

    // Method to deallocate memory.
    py::capsule free_when_done(sample_idx, [](void* mem_) {
        int32_t* mem = reinterpret_cast<int32_t*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(int32_t);
    return py::array(std::vector<int64_t>{num_samples + 1, 2},  // shape
                     {2 * byte_size, byte_size},                // C-style contiguous strides
                     sample_idx,                                // the data pointer
                     free_when_done);                           // numpy array references

}


inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
    /* Training sample length. */
    if (short_seq_ratio == 0) {
        return max_length;
    }
    const auto random_number = rand32_gen();
    if ((random_number % short_seq_ratio) == 0) {
        return 2 + random_number % (max_length - 1);
    }
    return max_length;
}


template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                             const py::array_t<int32_t>& sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
                             const bool verbose,
                             const int32_t min_num_sent) {
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length.
    */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(short_seq_prob >= 0.0);
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    // For efficiency, convert probability to ratio. Note: rand() generates int.
    int32_t short_seq_ratio = 0;
    if (short_seq_prob > 0) {
        short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
    }

    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << "    using:" << endl << std::flush;
        cout << "     number of documents:           " << docs_.shape(0) - 1 << endl << std::flush;
        cout << "     sentences range:               [" << sent_start_index << ", " << sent_end_index << ")" << endl << std::flush;
        cout << "     total number of sentences:     " << num_sentences << endl << std::flush;
        cout << "     number of epochs:              " << num_epochs << endl << std::flush;
        cout << "     maximum number of samples:     " << max_num_samples << endl << std::flush;
        cout << "     maximum sequence length:       " << max_seq_length << endl << std::flush;
        cout << "     short sequence probability:    " << short_seq_prob << endl << std::flush;
        cout << "     short sequence ratio (1/prob): " << short_seq_ratio << endl << std::flush;
        cout << "     seed:                          " << seed << endl << std::flush;
    }

    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Perform two iterations: in the first iteration get the size and
    // allocate memory, and in the second iteration populate the map.
    bool second = false;
    for (int32_t iteration = 0; iteration < 2; ++iteration) {

        // Set the seed so both iterations produce the same results.
        std::mt19937 rand32_gen(seed);

        // Set the flag on the second iteration.
        second = (iteration == 1);

        // Counters:
        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;

        // Current map index.
        uint64_t map_index = 0;

        // For each epoch:
        for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << "    reached " << max_num_samples << " samples after "
                         << epoch << " epochs ..." << endl << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {

                // Document sentences are in [sent_index_first, sent_index_last).
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // At the beginning of the document, the previous index is the
                // start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping.
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) {
                        ++empty_docs;
                    }
                    if (num_remain_sent == 1) {
                        ++one_sent_docs;
                    }
                }

                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent > 1) {
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN) {
                            if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have more than two sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
auto target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and if not only one sentence is left in the document.
|
||||
// and if we have at least two sentneces.
|
||||
// and if we have reached end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent > 1) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Check for overflow.
|
||||
if ((3 * map_index + 2) >
|
||||
std::numeric_limits<int64_t>::max()) {
|
||||
cout << "number of samples exceeded maximum "
|
||||
<< "allowed by type int64: "
|
||||
<< std::numeric_limits<int64_t>::max()
|
||||
<< endl;
|
||||
throw std::overflow_error("Number of samples");
|
||||
}
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 3 * map_index;
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
prev_start_index = sent_index + 1;
|
||||
target_seq_len = get_target_sample_len(short_seq_ratio,
|
||||
max_seq_length,
|
||||
rand32_gen);
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[3*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 3 * i;
|
||||
const auto j0 = 3 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
|
||||
{3*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
|
||||
py::array build_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const double short_seq_prob,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const int32_t min_num_sent) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
|
||||
max_num_samples, max_seq_length,
|
||||
short_seq_prob, seed, verbose,
|
||||
min_num_sent);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename DocIdx>
|
||||
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int32_t>& sizes_,
|
||||
const py::array_t<int32_t>& titles_sizes_,
|
||||
const int32_t num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int32_t max_seq_length,
|
||||
const int32_t seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
/* Build a mapping of (start-index, end-index, sequence-length) where
|
||||
start and end index are the indices of the sentences in the sample
|
||||
and sequence-length is the target sequence length.
|
||||
*/
|
||||
|
||||
// Consistency checks.
|
||||
assert(num_epochs > 0);
|
||||
assert(max_seq_length > 1);
|
||||
assert(seed > 0);
|
||||
|
||||
// Remove bound checks.
|
||||
auto docs = docs_.unchecked<1>();
|
||||
auto sizes = sizes_.unchecked<1>();
|
||||
auto titles_sizes = titles_sizes_.unchecked<1>();
|
||||
|
||||
if (verbose) {
|
||||
const auto sent_start_index = docs[0];
|
||||
const auto sent_end_index = docs[docs_.shape(0) - 1];
|
||||
const auto num_sentences = sent_end_index - sent_start_index;
|
||||
cout << " using:" << endl << std::flush;
|
||||
cout << " number of documents: " << docs_.shape(0) - 1 <<
|
||||
endl << std::flush;
|
||||
cout << " sentences range: [" << sent_start_index <<
|
||||
", " << sent_end_index << ")" << endl << std::flush;
|
||||
cout << " total number of sentences: " << num_sentences <<
|
||||
endl << std::flush;
|
||||
cout << " number of epochs: " << num_epochs <<
|
||||
endl << std::flush;
|
||||
cout << " maximum number of samples: " << max_num_samples <<
|
||||
endl << std::flush;
|
||||
cout << " maximum sequence length: " << max_seq_length <<
|
||||
endl << std::flush;
|
||||
cout << " seed: " << seed << endl <<
|
||||
std::flush;
|
||||
}
|
||||
|
||||
// Mapping and its length (1D).
|
||||
int64_t num_samples = -1;
|
||||
DocIdx* maps = NULL;
|
||||
|
||||
// Acceptable number of sentences per block.
|
||||
int min_num_sent = 2;
|
||||
if (use_one_sent_blocks) {
|
||||
min_num_sent = 1;
|
||||
}
|
||||
|
||||
// Perform two iterations, in the first iteration get the size
|
||||
// and allocate memory and in the second iteration populate the map.
|
||||
bool second = false;
|
||||
for (int32_t iteration=0; iteration<2; ++iteration) {
|
||||
|
||||
// Set the flag on second iteration.
|
||||
second = (iteration == 1);
|
||||
|
||||
// Current map index.
|
||||
uint64_t map_index = 0;
|
||||
|
||||
uint64_t empty_docs = 0;
|
||||
uint64_t one_sent_docs = 0;
|
||||
uint64_t long_sent_docs = 0;
|
||||
// For each epoch:
|
||||
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
|
||||
// assign every block a unique id
|
||||
int32_t block_id = 0;
|
||||
|
||||
if (map_index >= max_num_samples) {
|
||||
if (verbose && (!second)) {
|
||||
cout << " reached " << max_num_samples << " samples after "
|
||||
<< epoch << " epochs ..." << endl << std::flush;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// For each document:
|
||||
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
|
||||
|
||||
// Document sentences are in [sent_index_first, sent_index_last)
|
||||
const auto sent_index_first = docs[doc];
|
||||
const auto sent_index_last = docs[doc + 1];
|
||||
const auto target_seq_len = max_seq_length - titles_sizes[doc];
|
||||
|
||||
// At the begining of the document previous index is the
|
||||
// start index.
|
||||
auto prev_start_index = sent_index_first;
|
||||
|
||||
// Remaining documents.
|
||||
auto num_remain_sent = sent_index_last - sent_index_first;
|
||||
|
||||
// Some bookkeeping
|
||||
if ((epoch == 0) && (!second)) {
|
||||
if (num_remain_sent == 0) {
|
||||
++empty_docs;
|
||||
}
|
||||
if (num_remain_sent == 1) {
|
||||
++one_sent_docs;
|
||||
}
|
||||
}
|
||||
// Detect documents with long sentences.
|
||||
bool contains_long_sentence = false;
|
||||
if (num_remain_sent >= min_num_sent) {
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
if (sizes[sent_index] > LONG_SENTENCE_LEN){
|
||||
if ((epoch == 0) && (!second)) {
|
||||
++long_sent_docs;
|
||||
}
|
||||
contains_long_sentence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we have enough sentences and no long sentences.
|
||||
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
|
||||
|
||||
// Set values.
|
||||
auto seq_len = int32_t{0};
|
||||
auto num_sent = int32_t{0};
|
||||
|
||||
// Loop through sentences.
|
||||
for (auto sent_index=sent_index_first;
|
||||
sent_index < sent_index_last; ++sent_index) {
|
||||
|
||||
// Add the size and number of sentences.
|
||||
seq_len += sizes[sent_index];
|
||||
++num_sent;
|
||||
--num_remain_sent;
|
||||
|
||||
// If we have reached the target length.
|
||||
// and there are an acceptable number of sentences left
|
||||
// and if we have at least the minimum number of sentences.
|
||||
// or if we have reached end of the document.
|
||||
if (((seq_len >= target_seq_len) &&
|
||||
(num_remain_sent >= min_num_sent) &&
|
||||
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
|
||||
|
||||
// Populate the map.
|
||||
if (second) {
|
||||
const auto map_index_0 = 4 * map_index;
|
||||
// Each sample has 4 items: the starting sentence index, ending sentence index,
|
||||
// the index of the document from which the block comes (used for fetching titles)
|
||||
// and the unique id of the block (used for creating block indexes)
|
||||
|
||||
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
|
||||
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
|
||||
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
|
||||
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
|
||||
}
|
||||
|
||||
// Update indices / counters.
|
||||
++map_index;
|
||||
++block_id;
|
||||
prev_start_index = sent_index + 1;
|
||||
seq_len = 0;
|
||||
num_sent = 0;
|
||||
}
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
cout << " number of empty documents: " << empty_docs <<
|
||||
endl << std::flush;
|
||||
cout << " number of documents with one sentence: " <<
|
||||
one_sent_docs << endl << std::flush;
|
||||
cout << " number of documents with long sentences: " <<
|
||||
long_sent_docs << endl << std::flush;
|
||||
cout << " will create mapping for " << map_index <<
|
||||
" samples" << endl << std::flush;
|
||||
}
|
||||
assert(maps == NULL);
|
||||
assert(num_samples < 0);
|
||||
maps = new DocIdx[4*map_index];
|
||||
num_samples = static_cast<int64_t>(map_index);
|
||||
}
|
||||
|
||||
} // for (int iteration=0; iteration < 2; ++iteration) {
|
||||
|
||||
// Shuffle.
|
||||
// We need a 64 bit random number generator as we might have more
|
||||
// than 2 billion samples.
|
||||
std::mt19937_64 rand64_gen(seed + 1);
|
||||
for (auto i=(num_samples - 1); i > 0; --i) {
|
||||
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
|
||||
const auto i0 = 4 * i;
|
||||
const auto j0 = 4 * j;
|
||||
// Swap values.
|
||||
swap(maps[i0], maps[j0]);
|
||||
swap(maps[i0 + 1], maps[j0 + 1]);
|
||||
swap(maps[i0 + 2], maps[j0 + 2]);
|
||||
swap(maps[i0 + 3], maps[j0 + 3]);
|
||||
}
|
||||
|
||||
// Method to deallocate memory.
|
||||
py::capsule free_when_done(maps, [](void *mem_) {
|
||||
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
|
||||
delete[] mem;
|
||||
});
|
||||
|
||||
// Return the numpy array.
|
||||
const auto byte_size = sizeof(DocIdx);
|
||||
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
|
||||
{4*byte_size, byte_size}, // C-style contiguous strides
|
||||
maps, // the data pointer
|
||||
free_when_done); // numpy array references
|
||||
|
||||
}
|
||||
|
||||
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
|
||||
const py::array_t<int>& sizes_,
|
||||
const py::array_t<int>& titles_sizes_,
|
||||
const int num_epochs,
|
||||
const uint64_t max_num_samples,
|
||||
const int max_seq_length,
|
||||
const int seed,
|
||||
const bool verbose,
|
||||
const bool use_one_sent_blocks) {
|
||||
|
||||
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
|
||||
if (verbose) {
|
||||
cout << " using uint64 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
} else {
|
||||
if (verbose) {
|
||||
cout << " using uint32 for data mapping..." << endl << std::flush;
|
||||
}
|
||||
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
|
||||
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(helpers, m) {
|
||||
m.def("build_mapping", &build_mapping);
|
||||
m.def("build_blocks_mapping", &build_blocks_mapping);
|
||||
m.def("build_sample_idx", &build_sample_idx);
|
||||
m.def("build_blending_indices", &build_blending_indices);
|
||||
}
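
# Hypothetical Python-side usage of the extension above (a sketch, not part of
# the diff). It assumes the module is compiled and importable as `helpers`
# from this package; the toy sizes are made up for illustration. Arguments are
# positional because the m.def() bindings above declare no keyword names.
#
#   import numpy as np
#   from nemo.collections.nlp.data.language_modeling.megatron import helpers
#
#   sizes = np.array([5, 7, 4], dtype=np.int32)     # tokens per document
#   doc_idx = np.array([0, 1, 2], dtype=np.int32)   # one epoch, in order
#   # args: sizes, doc_idx, seq_length=4, num_epochs=1, tokens_per_epoch=16
#   sample_idx = helpers.build_sample_idx(sizes, doc_idx, 4, 1, 16)
#   # sample_idx has shape (num_samples + 1, 2): each row is a
#   # (document index, offset) pair marking where a sample of
#   # seq_length + 1 tokens starts in the flattened token stream.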
@@ -0,0 +1,564 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Most of the code here has been copied from:
#   fairseq/fairseq/data/indexed_dataset.py

# with some modifications:
#  - Removed IndexedRawTextDataset since it relied on the Fairseq dictionary,
#    plus other slight modifications to remove fairseq dependencies.
#  - Added a document index to the index file and made it accessible.
#  - An empty sentence no longer separates documents.

import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

import numpy as np
import torch

from nemo.utils import logging


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def get_available_dataset_impl():
    return ['lazy', 'cached', 'mmap']


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return 'cached'
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return 'mmap'
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == 'mmap':
        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None
    if impl == 'infer':
        impl = infer_dataset_impl(path)
    if impl == 'lazy' and IndexedDataset.exists(path):
        return IndexedDataset(path)
    elif impl == 'cached' and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    if impl == 'mmap':
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.double, 8: np.uint16}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx
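

# Quick illustration (not part of the original file): a zero-length entry
# marks the end of a document, so the returned offsets delimit documents in
# the flat `sizes` array:
#
#   >>> create_doc_idx([3, 4, 0, 2, 0])
#   [0, 3, 5]
#
# Document 0 covers entries [0, 3) and document 1 covers entries [3, 5),
# with each trailing zero-length separator counted inside its range.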


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""

    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):
    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx : ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work; can optimize later if necessary.
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents


class IndexedDatasetBuilder(object):
    element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float32: 4, np.double: 8}

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, tensor):
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes // self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), 'rb') as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()
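

# Illustrative round-trip (a sketch, not part of the original file): write two
# token-id tensors with IndexedDatasetBuilder, then read them back through the
# 'lazy' loader. The file paths are arbitrary examples.
#
#   builder = IndexedDatasetBuilder('/tmp/toy.bin', dtype=np.int32)
#   builder.add_item(torch.tensor([10, 11, 12]))
#   builder.add_item(torch.tensor([13, 14]))
#   builder.end_document()           # mark the document boundary
#   builder.finalize('/tmp/toy.idx')
#
#   ds = make_dataset('/tmp/toy', impl='lazy')
#   assert list(ds[0]) == [10, 11, 12] and len(ds) == 2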


def _warmup_mmap_file(path):
    with open(path, 'rb') as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b'MMIDIDX\x00\x00'

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, 'wb')

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack('<Q', 1))
                    self._file.write(struct.pack('<B', code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order='C'))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order='C'))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                logging.info(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            logging.info(" reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            logging.info(" reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
            )
            logging.info(" reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    # def __setstate__(self, state):
    #     self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            logging.info(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        logging.info(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        logging.info(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset, with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, 'wb')
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), 'rb') as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
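

# Illustrative round-trip for the mmap variant (a sketch, not part of the
# original file); the paths are arbitrary examples. skip_warmup=True skips the
# page-cache warm-up pass, which only matters for large files on cold disks.
#
#   builder = MMapIndexedDatasetBuilder('/tmp/toy_mmap.bin', dtype=np.int32)
#   builder.add_item(torch.tensor([1, 2, 3, 4]))
#   builder.end_document()
#   builder.finalize('/tmp/toy_mmap.idx')
#
#   ds = MMapIndexedDataset('/tmp/toy_mmap', skip_warmup=True)
#   full = ds[0]                   # -> array([1, 2, 3, 4], dtype=int32)
#   tail = ds.get(0, offset=2)     # -> array([3, 4], dtype=int32)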
@@ -0,0 +1,48 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron dataset class that handles initialization and other Megatron-specific things common to all datasets."""

import torch
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo
from nemo.utils import AppState, logging


class MegatronDataset(torch.utils.data.Dataset):
    """
    Base dataset class for Megatron GPT pretraining.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):
        app_state = AppState()

        if not app_state._is_megatron_initialized:
            logging.info(
                "Initializing Megatron since it hasn't been initialized by the model. This is normal if you are using a NeMo model with Megatron dataloaders."
            )
            app_state.global_rank = trainer.global_rank
            app_state.world_size = trainer.world_size
            app_state.model_parallel_size = 1
            app_state.model_parallel_rank = trainer.global_rank

            initialize_model_parallel_for_nemo(
                world_size=trainer.world_size,
                global_rank=trainer.global_rank,
                local_rank=trainer.local_rank,
                tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
                seed=cfg.get('seed', 1234),
            )
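

# Sketch of the intended usage (illustrative, not part of the original file):
# concrete datasets call super().__init__(cfg, trainer) first so that model
# parallelism is initialized on this rank before any data indexing happens.
# The subclass name and its `samples` argument are hypothetical.
#
#   class MyCorpusDataset(MegatronDataset):
#       def __init__(self, cfg: DictConfig, trainer: Trainer, samples):
#           super().__init__(cfg, trainer)  # ensures Megatron init on this rank
#           self.samples = samples
#
#       def __len__(self):
#           return len(self.samples)
#
#       def __getitem__(self, idx):
#           return self.samples[idx]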
@@ -0,0 +1,322 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5 Style dataset."""

import collections

import numpy as np
import torch
from megatron import get_tokenizer

from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
    create_masked_lm_predictions,
    get_samples_mapping,
)


class T5Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        name,
        indexed_dataset,
        data_prefix,
        num_epochs,
        max_num_samples,
        masked_lm_prob,
        max_seq_length,
        max_seq_length_dec,
        short_seq_prob,
        seed,
    ):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.max_seq_length_dec = max_seq_length_dec

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(
            self.indexed_dataset,
            data_prefix,
            num_epochs,
            max_num_samples,
            self.max_seq_length - 2,  # account for added tokens
            short_seq_prob,
            self.seed,
            self.name,
            False,
        )

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        self.sentinel_tokens = tokenizer.additional_special_tokens_ids
        assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):

        start_index, end_index, seq_length = self.samples_mapping[idx]
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_training_sample(
            sample,
            seq_length,
            self.max_seq_length,  # needed for padding
            self.max_seq_length_dec,
            self.vocab_id_list,
            self.vocab_id_to_token_dict,
            self.cls_id,
            self.sep_id,
            self.mask_id,
            self.pad_id,
            self.masked_lm_prob,
            np_rng,
            self.bos_id,
            self.eos_id,
            self.sentinel_tokens,
        )


def build_training_sample(
    sample,
    target_seq_length,
    max_seq_length,
    max_seq_length_dec,
    vocab_id_list,
    vocab_id_to_token_dict,
    cls_id,
    sep_id,
    mask_id,
    pad_id,
    masked_lm_prob,
    np_rng,
    bos_id=None,
    eos_id=None,
    sentinel_tokens=None,
):
    """Build training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        max_seq_length_dec: Maximum length of the decoder sequence.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
        bos_id: start of decoder example id
        eos_id: end of generation id
        sentinel_tokens: unique value to be substituted for every replaced span
    """

    assert target_seq_length <= max_seq_length

    # Flatten sentences into one list.
    tokens = [token for sentence in sample for token in sentence]

    # Truncate to `target_seq_length`.
    max_num_tokens = target_seq_length
    truncated = len(tokens) > max_num_tokens
    tokens = tokens[:max_num_tokens]

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
        tokens,
        vocab_id_list,
        vocab_id_to_token_dict,
        masked_lm_prob,
        cls_id,
        sep_id,
        mask_id,
        max_predictions_per_seq,
        np_rng,
        max_ngrams=10,
        geometric_dist=True,
        masking_style="t5",
    )

    # Padding.
    tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask = pad_and_convert_to_numpy(
        tokens,
        masked_positions,
        masked_labels,
        pad_id,
        max_seq_length,
        max_seq_length_dec,
        masked_spans,
        bos_id,
        eos_id,
        sentinel_tokens,
    )

    train_sample = {
        'text_enc': tokens_enc,
        'text_dec': tokens_dec_in,
        'labels': labels,
        'loss_mask': loss_mask,
        'truncated': int(truncated),
        'enc_mask': enc_mask,
        'dec_mask': dec_mask,
        'enc_dec_mask': enc_dec_mask,
    }
    return train_sample


def pad_and_convert_to_numpy(
    tokens,
    masked_positions,
    masked_labels,
    pad_id,
    max_seq_length,
    max_seq_length_dec,
    masked_spans=None,
    bos_id=None,
    eos_id=None,
    sentinel_tokens=None,
):
    """Pad sequences and convert them to numpy."""

    sentinel_tokens = collections.deque(sentinel_tokens)
    t5_input = []
    (t5_decoder_in, t5_decoder_out) = ([bos_id], [])
    (start_index, end_index) = (0, None)
    for span in masked_spans:
        flag = sentinel_tokens.popleft()

        # Append the same tokens in decoder input and output
        t5_decoder_in.append(flag)
        t5_decoder_in.extend(span.label)
        t5_decoder_out.append(flag)
        t5_decoder_out.extend(span.label)

        end_index = span.index[0]
        t5_input.extend(tokens[start_index:end_index])
        t5_input.append(flag)

        # the next start index is the token after the last span token
        start_index = span.index[-1] + 1

    # Add <eos> token to the t5_decoder_out
    t5_decoder_out.append(eos_id)

    # Add the remaining tokens to the t5 input
    t5_input.extend(tokens[start_index:])

    # assert (len(t5_input) - len(masked_spans)) + \
    #     (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)

    # Some checks.

    # Encoder-side padding mask.
    num_tokens = len(t5_input)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(masked_positions) == len(masked_labels)

    # Tokens.
    filler = [pad_id] * padding_length
    tokens_enc = np.array(t5_input + filler, dtype=np.int64)

    # Decoder-side padding mask.
    num_tokens_dec = len(t5_decoder_in)
    padding_length_dec = max_seq_length_dec - num_tokens_dec
    assert padding_length_dec >= 0
    filler_dec = [pad_id] * padding_length_dec
    tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)

    # Create attention masks.
    enc_mask = make_attention_mask(tokens_enc, tokens_enc)
    enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
    dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
    dec_mask = dec_mask * make_history_mask(tokens_dec_in)

    # Labels mask.
    labels = t5_decoder_out + ([-1] * padding_length_dec)
    labels = np.array(labels, dtype=np.int64)

    # Loss mask.
    loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
    loss_mask = np.array(loss_mask, dtype=np.int64)

    return tokens_enc, tokens_dec_in, labels, enc_mask, dec_mask, enc_dec_mask, loss_mask
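

# Worked example of the span-corruption encoding above (illustrative, not part
# of the original file). Each `masked_spans` entry exposes `.index` (positions
# of the masked span in `tokens`) and `.label` (the original token ids); the
# namedtuple below just mimics that shape, and the ids are arbitrary.
#
#   Span = collections.namedtuple('Span', ['index', 'label'])
#   out = pad_and_convert_to_numpy(
#       tokens=[11, 12, 99, 99, 15],          # positions 2-3 already masked
#       masked_positions=[2, 3],
#       masked_labels=[13, 14],
#       pad_id=0, max_seq_length=8, max_seq_length_dec=8,
#       masked_spans=[Span(index=[2, 3], label=[13, 14])],
#       bos_id=1, eos_id=2, sentinel_tokens=[250000],
#   )
#   # encoder input : [11, 12, 250000, 15, 0, 0, 0, 0]
#   # decoder input : [1, 250000, 13, 14, 0, 0, 0, 0]
#   # decoder labels: [250000, 13, 14, 2, -1, -1, -1, -1]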


def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask.
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask


def make_attention_mask_3d(source_block, target_block):
    """
    Returns a 3-dimensional (3-D) attention mask.
    :param source_block: 2-D array (batch, source_length)
    :param target_block: 2-D array (batch, target_length)
    """
    mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
    # (batch, source_length, target_length)
    # mask = mask.astype(np.int64)
    return mask


def make_history_mask(block):
    length = block.shape[0]
    arange = np.arange(length)
    history_mask = arange[None,] <= arange[:, None]
    history_mask = history_mask.astype(np.int64)
    return history_mask


def make_history_mask_3d(block):
    batch, length = block.shape
    arange = torch.arange(length, device=block.device)
    history_mask = (arange[None,] <= arange[:, None])[
        None,
    ]
    history_mask = history_mask.expand(batch, length, length)
    return history_mask
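

# Quick shape check (illustrative, not part of the original file): pad ids
# (0) drop out of the attention mask, and the history mask is lower-triangular.
#
#   seq = np.array([5, 6, 0])            # one padded position
#   make_attention_mask(seq, seq)
#   # -> [[1, 1, 0],
#   #     [1, 1, 0],
#   #     [0, 0, 0]]
#   make_history_mask(seq)
#   # -> [[1, 0, 0],
#   #     [1, 1, 0],
#   #     [1, 1, 1]]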
@@ -0,0 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from nemo.collections.nlp.models.language_modeling.megatron.bert_model import BertModel
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel

# from nemo.collections.nlp.models.language_modeling.megatron.t5_model import T5Model
@@ -0,0 +1,221 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch
from apex.transformer import tensor_parallel

from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import (
    erf_gelu,
    get_linear_layer,
    init_method_normal,
    openai_gelu,
    scaled_init_method_normal,
)


def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert attention mask to binary:
    extended_attention_mask = extended_attention_mask < 0.5

    return extended_attention_mask


def bert_position_ids(token_ids):
    # Create position ids.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
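

# Toy illustration (not part of the original file): positions where the
# resulting boolean mask is True are the ones *masked out* downstream (note
# the `< 0.5` inversion above).
#
#   pad_mask = torch.tensor([[1, 1, 0]])            # [b=1, s=3]
#   bert_extended_attention_mask(pad_mask).shape    # -> torch.Size([1, 1, 3, 3])
#   bert_position_ids(pad_mask)                     # -> tensor([[0, 1, 2]])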


class BertLMHead(MegatronModule):
    """Masked LM head for BERT.

    Arguments:
        mpu_vocab_size: model-parallel size of the vocabulary.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether the output logits stay distributed across
            tensor-parallel ranks or are gathered.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method, layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        parallel_state.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias)
        return output


def post_language_model_processing(
    lm_output, pooled_output, lm_head, binary_head, lm_labels, logit_weights, fp16_lm_cross_entropy
):
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
        return lm_loss, binary_logits


class BertModel(MegatronModule):
    """BERT language model."""

    def __init__(
        self, num_tokentypes=2, add_binary_head=True, parallel_output=True, pre_process=True, post_process=True
    ):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size,
                init_method,
                args.layernorm_epsilon,
                parallel_output,
            )
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2, init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask, tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids, position_ids, extended_attention_mask, tokentype_ids=tokentype_ids)

        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(
                lm_output,
                pooled_output,
                self.lm_head,
                self.binary_head,
                lm_labels,
                self.word_embeddings_weight(),
                self.fp16_lm_cross_entropy,
            )
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        if self.post_process:
            state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
                destination, prefix, keep_vars
            )
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
|
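For orientation, the nested checkpoint layout produced by state_dict_for_save_checkpoint above looks roughly like the following (a sketch; the first three key names come from the attributes above, the last one is the conventional Megatron key and is an assumption here):

# {
#     'language_model': {...},           # self._language_model_key
#     'lm_head': {...},                  # only when post_process=True
#     'binary_head': {...},              # only when add_binary_head=True
#     'word_embeddings_for_head': {...}  # only on post-process-only pipeline stages
# }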
@ -0,0 +1,191 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT-2 model."""

import torch
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType

from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal, scaled_init_method_normal


def post_language_model_processing(
    lm_output,
    labels,
    logit_weights,
    get_key_value,
    parallel_output,
    forward_method_parallel_output,
    fp16_lm_cross_entropy,
):
    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
        return loss
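The fp16_lm_cross_entropy branch trades a little numerical headroom for memory and bandwidth by skipping the upcast. A minimal standalone sketch of the two paths, using plain torch.nn.functional.cross_entropy in place of the tensor-parallel kernel:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, 32)          # [batch, seq, vocab]; think of these as model activations
labels = torch.randint(0, 32, (4, 8))

# Default path: upcast half-precision logits to fp32 before softmax/NLL for stability.
loss_fp32 = F.cross_entropy(logits.half().float().view(-1, 32), labels.view(-1))

# fp16_lm_cross_entropy path: keep logits in half precision. GPU only, since
# CPU log_softmax does not support fp16.
if torch.cuda.is_available():
    loss_fp16 = F.cross_entropy(logits.half().cuda().view(-1, 32), labels.cuda().view(-1))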

class GPTModel(MegatronModule):
    """GPT-2 Language model."""

    def __init__(
        self,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_lm_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        fused_fp16=False,
        fused_bf16=False,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
    ):
        super(GPTModel, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy

        assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        self.language_model, self._language_model_key = get_language_model(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            hidden_dropout=hidden_dropout,
            num_tokentypes=num_tokentypes,
            max_position_embeddings=max_position_embeddings,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            fused_fp16=fused_fp16,
            fused_bf16=fused_bf16,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        labels=None,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        forward_method_parallel_output=None,
    ):

        lm_output = self.language_model(
            input_ids, position_ids, attention_mask, layer_past=layer_past, get_key_value=get_key_value
        )

        if self.post_process:
            return post_language_model_processing(
                lm_output,
                labels,
                self.word_embeddings_weight(),
                get_key_value,
                self.parallel_output,
                forward_method_parallel_output,
                self.fp16_lm_cross_entropy,
            )
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):

        state_dict_ = {}
        state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
                destination, prefix, keep_vars
            )
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
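kv_channels defaults to the per-head width derived in the constructor above; a quick standalone check of that arithmetic:

hidden_size, num_attention_heads = 768, 12
assert hidden_size % num_attention_heads == 0
kv_channels = hidden_size // num_attention_heads  # 64: the dimension of each attention head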
@ -0,0 +1,163 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5 model."""

import torch
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from megatron import get_args

from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model, parallel_lm_logits
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.transformer import LayerNorm
from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal, scaled_init_method_normal


def t5_extended_attention_mask(attention_mask_list):
    def attn_mask_postprocess(attn_mask):
        # [b, 1, s, s]
        extended_attention_mask = attn_mask.unsqueeze(1)
        return extended_attention_mask

    return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]


def t5_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
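Both helpers are plain tensor ops; a runnable sketch of the shapes involved:

import torch

token_ids = torch.zeros(2, 5, dtype=torch.long)                    # [batch, seq]
position_ids = torch.arange(5).unsqueeze(0).expand_as(token_ids)   # same result as t5_position_ids
attn_mask = torch.ones(2, 5, 5)                                    # [b, s, s]
extended = attn_mask.unsqueeze(1)                                  # [b, 1, s, s], as in attn_mask_postprocess
assert extended.shape == (2, 1, 5, 5)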
class T5LMHead(MegatronModule):
    """Masked LM head for T5

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether the output logits are kept distributed or gathered.
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
        output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias)
        return output


class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_decoder=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
        )

        self.lm_head = T5LMHead(self.language_model.embedding.word_embeddings.weight.size(0), parallel_output)
        self._lm_head_key = 'lm_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
        tokentype_ids=None,
        lm_labels=None,
        enc_hidden_states=None,
    ):

        # Converting the attention masks to proper parameter settings
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
        )

        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        lm_output = self.language_model(
            encoder_input_ids,
            encoder_position_ids,
            encoder_attn_mask,
            decoder_input_ids,
            decoder_position_ids,
            decoder_attn_mask,
            encoder_decoder_attn_mask,
            tokentype_ids=tokentype_ids,
            enc_hidden_states=enc_hidden_states,
        )

        decoder_output, encoder_output = lm_output

        # Output.
        lm_logits = self.lm_head(decoder_output, self.language_model.embedding.word_embeddings.weight)

        if lm_labels is None:
            return lm_logits, encoder_output
        else:
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
            else:
                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
            return lm_loss, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict)
        self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict)
@ -0,0 +1,406 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, Optional

import torch
from apex.transformer import parallel_state, tensor_parallel
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
    MegatronPretrainingRandomSampler,
    MegatronPretrainingSampler,
)
from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
from nemo.collections.nlp.models.nlp_model import NLPModel
from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32
from nemo.collections.nlp.modules.common.megatron.megatron_init import (
    initialize_model_parallel_for_nemo,
    set_jit_fusion_options,
)
from nemo.collections.nlp.modules.common.megatron.utils import (
    average_losses_across_data_parallel_group,
    get_ltor_masks_and_position_ids,
)
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.nlp.parts.nlp_overrides import NLPNativeMixedPrecisionPlugin
from nemo.utils import AppState, logging


class MegatronGPTModel(NLPModel):
    """
    Megatron GPT pretraining
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer=trainer)
        self.cfg = cfg

        if self.cfg.get('use_cpu_initialization', False) is False:
            torch.cuda.set_device(trainer.local_rank)

        # buffer used during train_step for logging average loss over gradient accumulation steps
        self._reduced_loss_buffer = []

        initialize_model_parallel_for_nemo(
            world_size=trainer.world_size,
            global_rank=trainer.global_rank,
            local_rank=trainer.local_rank,
            tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
            seed=self.cfg.get('seed', 1234),
        )

        if not self.cfg.get('fused_bf16'):
            set_jit_fusion_options()

        self.tokenizer = get_nmt_tokenizer(
            library=self.cfg.tokenizer.library,
            model_name=self.cfg.tokenizer.type,
            tokenizer_model=self.register_artifact("tokenizer_model", self.cfg.tokenizer.model),
            vocab_file=self.register_artifact("vocab_file", self.cfg.tokenizer.vocab_file),
            merges_file=self.register_artifact("merges_file", self.cfg.tokenizer.merge_file),
        )

        vocab_size = self.tokenizer.vocab_size

        padded_vocab_size = self._vocab_size_with_padding(
            orig_vocab_size=vocab_size,
            make_vocab_size_divisible_by=cfg.get('make_vocab_size_divisible_by', 128),
            tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
        )

        self.model = GPTModel(
            vocab_size=padded_vocab_size,
            hidden_size=cfg.hidden_size,
            max_position_embeddings=cfg.max_position_embeddings,
            num_layers=cfg.num_layers,
            num_attention_heads=cfg.num_attention_heads,
            apply_query_key_layer_scaling=cfg.get('apply_query_key_layer_scaling', True),
            kv_channels=cfg.get('kv_channels', None),
            ffn_hidden_size=cfg.ffn_hidden_size,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=cfg.get('pre_process', True),
            post_process=cfg.get('post_process', True),
            init_method_std=cfg.get('init_method_std', 0.02),
            fp16_lm_cross_entropy=cfg.get('fp16_lm_cross_entropy', False),
            use_cpu_initialization=cfg.get('use_cpu_initialization', False),
            hidden_dropout=cfg.get('hidden_dropout', 0.1),
            fused_fp16=cfg.get('fused_fp16', False),
            fused_bf16=cfg.get('fused_bf16', False),
            fp32_residual_connection=cfg.get('fp32_residual_connection', False),
            activations_checkpoint_method=cfg.get('activations_checkpoint_method', None),
            activations_checkpoint_num_layers=cfg.get('activations_checkpoint_num_layers', 1),
            layernorm_epsilon=cfg.get('layernorm_epsilon', 1e-5),
            onnx_safe=cfg.get('onnx_safe', False),
        )
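The constructor reads its hyperparameters from an OmegaConf/Hydra config; a minimal illustrative fragment (field names are taken from the cfg accesses above, the values are only examples, and real configs live in NeMo's YAML files):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        'hidden_size': 768,
        'num_layers': 12,
        'num_attention_heads': 12,
        'ffn_hidden_size': 3072,
        'max_position_embeddings': 512,
        'tensor_model_parallel_size': 1,
        'make_vocab_size_divisible_by': 128,
        'use_cpu_initialization': False,
    }
)
# Missing keys fall back to the defaults passed to cfg.get(), e.g. kv_channels -> None,
# which GPTModel then derives as hidden_size // num_attention_heads.
assert cfg.get('kv_channels', None) is None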
    def forward(self, tokens, position_ids, attention_mask, labels):
        output_tensor = self.model(tokens, position_ids, attention_mask, labels=labels)
        return output_tensor

    def training_step(self, batch, batch_idx):
        tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
        output_tensor = self(tokens, position_ids, attention_mask, labels)
        loss = self.loss_func(loss_mask, output_tensor)
        reduced_loss = average_losses_across_data_parallel_group([loss])

        # cache reduced loss while accumulating gradients
        self._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step), prog_bar=True)
            self._reduced_loss_buffer = []
        return loss

    def validation_step(self, batch, batch_idx):
        tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
        output_tensor = self(tokens, position_ids, attention_mask, labels)
        loss = self.loss_func(loss_mask, output_tensor)
        # TODO: add text generation
        # take the first k tokens and then generate text - compare with ground truth (won't look similar in general)
        """
        k = num_context
        n = max_generate_length
        context_tokens = tokens[0:k]
        while k < n:
            output_tensor = self(context_tokens)
            next_token = sample(output_tensor)
            context_tokens.append(next_token)
            k += 1
        """
        return loss
    def validation_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)
        self.log('val_loss', averaged_loss[0], prog_bar=True)
        self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step))

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)
        logging.info(f'test_loss: {averaged_loss[0]}')

    def loss_func(self, loss_mask, output_tensor):
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        # TODO: add nemo version here
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()  # sequence level nll
        return loss

    def process_batch(self, batch):

        # Items and their type.
        keys = ['text']
        datatype = torch.int64

        data = batch
        data_b = tensor_parallel.broadcast_data(keys, data, datatype)

        # Unpack.
        tokens_ = data_b['text'].long()
        labels = tokens_[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()

        # Get the masks and position ids.
        attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
            tokens,
            self.tokenizer.eos_id,
            self.cfg.data.get('reset_position_ids', False),
            self.cfg.data.get('reset_attention_mask', False),
            self.cfg.data.get('eod_mask_loss', False),
        )

        return tokens, labels, loss_mask, attention_mask, position_ids
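process_batch builds next-token targets by shifting the same sequence by one position; a toy example:

import torch

tokens_ = torch.tensor([[10, 11, 12, 13]])
tokens = tokens_[:, :-1].contiguous()  # model input:  [10, 11, 12]
labels = tokens_[:, 1:].contiguous()   # targets:      [11, 12, 13]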
    def build_train_valid_test_datasets(self):
        logging.info('Building GPT datasets.')
        global_batch_size = self.trainer.world_size * self.cfg.micro_batch_size / self.cfg.tensor_model_parallel_size
        eval_iters = (self.trainer.max_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
        test_iters = self.trainer.limit_test_batches
        train_valid_test_num_samples = [
            self.trainer.max_steps * global_batch_size,
            eval_iters * global_batch_size,
            test_iters * global_batch_size,
        ]
        self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
            cfg=self.cfg,
            trainer=self.trainer,
            data_prefix=self.cfg.data.data_prefix,
            data_impl=self.cfg.data.data_impl,
            splits_string=self.cfg.data.splits_string,
            train_valid_test_num_samples=train_valid_test_num_samples,
            seq_length=self.cfg.data.seq_length,
            seed=self.cfg.seed,
            skip_warmup=self.cfg.data.get('skip_warmup', True),
        )
        if self._train_ds is not None:
            logging.info(f'Length of train dataset: {len(self._train_ds)}')
        if self._validation_ds is not None:
            logging.info(f'Length of val dataset: {len(self._validation_ds)}')
        if self._test_ds is not None:
            logging.info(f'Length of test dataset: {len(self._test_ds)}')
        logging.info('Finished building GPT datasets.')
        return self._train_ds, self._validation_ds, self._test_ds

    def build_pretraining_data_loader(self, dataset, consumed_samples):
        """Build dataloader given an input dataset."""

        if dataset is None:
            return None

        # Megatron sampler
        if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
            if self.cfg.data.dataloader_type == 'single':
                batch_sampler = MegatronPretrainingSampler(
                    total_samples=len(dataset),
                    consumed_samples=consumed_samples,
                    micro_batch_size=self.cfg.micro_batch_size,
                    data_parallel_rank=parallel_state.get_data_parallel_rank(),
                    data_parallel_size=parallel_state.get_data_parallel_world_size(),
                )
            elif self.cfg.data.dataloader_type == 'cyclic':
                batch_sampler = MegatronPretrainingRandomSampler(
                    total_samples=len(dataset),
                    consumed_samples=consumed_samples,
                    micro_batch_size=self.cfg.micro_batch_size,
                    data_parallel_rank=parallel_state.get_data_parallel_rank(),
                    data_parallel_size=parallel_state.get_data_parallel_world_size(),
                )
            else:
                raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
        else:
            raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"')

        # Torch dataloader.
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, num_workers=self.cfg.data.num_workers, pin_memory=True,
        )

    def setup(self, stage=None):
        if stage == 'predict':
            return
        # TODO: consider adding a ModelPT guard to check if model is being restored.
        # allowing restored models to optionally setup datasets
        self.build_train_valid_test_datasets()
        self.setup_training_data(self.cfg.data)
        self.setup_validation_data(self.cfg.data)
        self.setup_test_data(self.cfg.data)

    def setup_training_data(self, cfg):
        if hasattr(self, '_train_ds'):
            resume_checkpoint_path = self.trainer.checkpoint_connector.resume_checkpoint_path
            if resume_checkpoint_path:
                consumed_samples = int(
                    float(re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", resume_checkpoint_path)[0])
                )
            else:
                consumed_samples = 0
            logging.info(
                f'Setting up train dataloader with len(self._train_ds): {len(self._train_ds)} and consumed samples: {consumed_samples}'
            )
            self._train_dl = self.build_pretraining_data_loader(self._train_ds, consumed_samples)

    def setup_validation_data(self, cfg):
        if hasattr(self, '_validation_ds'):
            consumed_samples = 0
            logging.info(
                f'Setting up validation dataloader with len(self._validation_ds): {len(self._validation_ds)} and consumed samples: {consumed_samples}'
            )
            self._validation_dl = self.build_pretraining_data_loader(self._validation_ds, consumed_samples)

    def setup_test_data(self, cfg):
        if hasattr(self, '_test_ds'):
            consumed_samples = 0
            logging.info(
                f'Setting up test dataloader with len(self._test_ds): {len(self._test_ds)} and consumed samples: {consumed_samples}'
            )
            self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)
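setup_training_data recovers the consumed-sample count from the resume checkpoint's filename; a runnable check of that regex (the path here is illustrative, not an actual NeMo filename):

import re

path = 'megatron_gpt--val_loss=1.23-step=1000-consumed_samples=256000.0.ckpt'
consumed_samples = int(float(re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", path)[0]))
assert consumed_samples == 256000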
    def compute_consumed_samples(self, global_step):
        app_state = AppState()
        consumed_samples = (
            global_step
            * app_state.data_parallel_size
            * self.cfg.micro_batch_size
            * self.trainer.accumulate_grad_batches
        )
        return int(consumed_samples)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        """PTL hook that is called after unscaling gradients when using native amp.
        We use the gradient clipping implementation from megatron-lm.
        """
        clip_val = self.trainer.gradient_clip_val
        if clip_val is None:
            return

        clip_val = float(clip_val)
        if clip_val <= 0:
            return

        if isinstance(self.trainer.accelerator_connector.precision_plugin, NLPNativeMixedPrecisionPlugin):
            parameters = self.model.parameters()
            clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
        else:
            return
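compute_consumed_samples is just global step times data-parallel world size times micro batch size times accumulation; for example:

global_step, data_parallel_size, micro_batch_size, accumulate_grad_batches = 1000, 8, 4, 8
consumed_samples = global_step * data_parallel_size * micro_batch_size * accumulate_grad_batches
assert consumed_samples == 256000  # matches the checkpoint-name example above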
    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
        request = batch
        response = self.complete(request)
        return response

    def complete(self, request: Dict):
        """
        Autoregressively invokes the language model in inference mode.
        Args:
            request: Dictionary with the following fields
                * prompt: a string of text the model should complete.
                * tokens_to_generate: how many tokens to generate while doing prompt completion.
                * stop_after_sentence: (default True) whether to stop generation once a sentence end is reached.
        Returns:
            response: A python dictionary with the following fields
                * prompt: original text of the prompt
                * tokenized_prompt: list of (str) tokens from prompt
                * completion: a python dictionary with the following subfields:
                    * tokens: a list of triples (token, token_id, log_prob) comprising completion
                    * stop reason: either 'eos', 'sentence_end' or 'limit' indicating why generation stopped
                    * text: completion text (as a single string)
        """
        response = {}
        self.freeze()
        logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)
        response['tokenized_prompt'] = request['tokenized_prompt']
        tokens = request['tokens']
        # naive greedy slow loop
        # TODO: add option for BeamSearchDecoder
        response['prompt'] = request['prompt']
        response['completion'] = {}
        response['completion']['stop reason'] = 'limit'
        for i in range(request.get("tokens_to_generate", 64)):
            attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                data=tokens,
                eod_token=self.tokenizer.eos_id,
                reset_position_ids=self.cfg.get('reset_position_ids', False),
                reset_attention_mask=self.cfg.get('reset_attention_mask', False),
                eod_mask_loss=self.cfg.get('eod_mask_loss', False),
            )
            # No labels during inference. Still need masks to not attend to the right.
            output_tensor = self(tokens, position_ids, attention_mask, labels=None)
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
            reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
            tokens = torch.cat([torch.squeeze(tokens), token_ids[:, -1]])
            response['completion']["tokens"] = list(
                zip(self.tokenizer.ids_to_tokens(tokens), tokens.tolist(), log_probs.tolist()[0])
            )
            completion_text = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"])
            if reached_eos:  # Will it actually ever reach that?
                response['completion']['stop reason'] = 'eos'
                break
            elif request.get("stop_after_sentence", True) and completion_text.endswith(('.', '!', '?')):
                response['completion']['stop reason'] = 'sentence_end'
                break
            tokens = torch.unsqueeze(tokens, 0)
        response['completion']["text"] = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"])
        self.unfreeze()
        return response

    def list_available_models(self):
        return None

    def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size):
        """Pad vocab size so it is divisible by model parallel size and
        still has a GPU-friendly size."""

        after = orig_vocab_size
        multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
        while (after % multiple) != 0:
            after += 1
        logging.info(
            f'Padded vocab_size: {after}, original vocab_size: {orig_vocab_size}, dummy tokens: {after - orig_vocab_size}.'
        )
        return after
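_vocab_size_with_padding rounds the vocabulary up to the next multiple of make_vocab_size_divisible_by times tensor_model_parallel_size, so each tensor-parallel shard gets an equal, hardware-friendly slice; worked through for the GPT-2 vocabulary:

orig_vocab_size, make_divisible_by, tp_size = 50257, 128, 2
multiple = make_divisible_by * tp_size  # 256
after = orig_vocab_size
while after % multiple != 0:
    after += 1
assert after == 50432  # 175 dummy tokens added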
@ -15,32 +15,29 @@
import hashlib
import json
import os
import shutil
import tarfile
import tempfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional

import torch
-from megatron import mpu
-from megatron.checkpointing import get_checkpoint_version, set_checkpoint_version
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
from transformers import TRANSFORMERS_CACHE

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.modules import BertModule, MegatronBertEncoder
from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import VOCAB_FILE_NAME
+from nemo.collections.nlp.modules.common.megatron.megatron_bert import (
+    get_megatron_checkpoint_version,
+    set_megatron_checkpoint_version,
+)
from nemo.collections.nlp.modules.common.megatron.megatron_encoder import MegatronEncoderModule
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
-from nemo.collections.nlp.parts.nlp_overrides import NLPCheckpointConnector
+from nemo.collections.nlp.parts.nlp_overrides import NLPCheckpointConnector, NLPSaveRestoreConnector
from nemo.core.classes import ModelPT
from nemo.core.classes.exportable import Exportable
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.utils import AppState, logging
from nemo.utils.get_rank import is_global_rank_zero

__all__ = ['NLPModel']


@ -55,6 +52,8 @@ class NLPModel(ModelPT, Exportable):

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg, trainer)
+        # handles model parallel save and restore logic
+        self._save_restore_connector = NLPSaveRestoreConnector()
        self.set_world_size(trainer)

    def register_artifact(
@ -180,274 +179,31 @@ class NLPModel(ModelPT, Exportable):
                f'Registering tokenizer vocab for {self.tokenizer} is not yet supported. Please override this method if needed.'
            )

    def _clip_gradients(self, optimizer, clip_val=None):
        """ Override of PTL Gradient Clipping.
            Enables model parallel gradient clipping from Megatron-LM.

        Args:
            optimizer ([type]): [description]
            clip_val ([type], optional): [description]. Defaults to None.
        """
        app_state = AppState()

        # get clip_val from trainer if None is provided
        if clip_val is None:
            clip_val = float(self._trainer.gradient_clip_val)

        if app_state.model_parallel_size is not None:
            model = self._trainer.get_model()
            parameters = model.parameters()
            if mpu.model_parallel_is_initialized():
                mpu.grads.clip_grad_norm(parameters=parameters, max_norm=clip_val)
            else:
                raise ValueError('Model parallel groups must be initialized to use model parallel gradient clipping.')

        else:
            return Accelerator._clip_gradients(self, optimizer, clip_val)

    def setup(self, stage: str) -> None:
        """ PTL hook that is called on all DDP processes. """

        if stage == 'fit':

            # adds self.bert_model config to .nemo file
            if hasattr(self, 'bert_model') and self.bert_model is not None:
                self.register_bert_model()

            app_state = AppState()

            if app_state.model_parallel_size is not None:

                self._trainer.checkpoint_connector = NLPCheckpointConnector(self._trainer)

                # # Configure checkpointing for model parallel
                # if app_state.create_checkpoint_callback:
                #     # global rank 0 is configured by exp_manager
                #     if not is_global_rank_zero() and app_state.data_parallel_rank == 0:
                #         configure_checkpointing(
                #             self._trainer,
                #             app_state.log_dir,
                #             app_state.checkpoint_name,
                #             app_state.checkpoint_callback_params,
                #         )

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """ LightningModule hook that's used to save things in addition to model weights. """

        if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
-            checkpoint['checkpoint_version'] = get_checkpoint_version()
+            checkpoint['checkpoint_version'] = get_megatron_checkpoint_version()
        return None

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """ LightningModule hook that's used to restore things saved with on_save_checkpoint."""

        if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
-            if get_checkpoint_version():
+            if get_megatron_checkpoint_version():
                assert (
-                    checkpoint['checkpoint_version'] == get_checkpoint_version()
-                ), 'checkpoint version found on_load_checkpoint different than get_checkpoint_version'
+                    checkpoint['checkpoint_version'] == get_megatron_checkpoint_version()
+                ), 'checkpoint version found on_load_checkpoint different than get_megatron_checkpoint_version'
            else:
-                set_checkpoint_version(checkpoint['checkpoint_version'])
+                set_megatron_checkpoint_version(checkpoint['checkpoint_version'])
                logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}")
        return None

    # no rank check as model parallel models need to be saved on data parallel rank 0
    def save_to(self, save_path: str):
        """
        Saves model instance (weights and configuration) into .nemo file
        You can use "restore_from" method to fully restore instance from .nemo file.

        .nemo file is an archive (tar.gz) with the following:
            model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor
            model_weights.ckpt - model checkpoint

        Args:
            save_path: Path to .nemo file where model instance should be saved
        """
        save_path = os.path.abspath(save_path)
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            self._default_save_to(save_path)
        else:
            # super.save_to only runs on global rank 0
            return super().save_to(save_path)

    def _default_save_to(self, save_path: str):
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # each model parallel rank creates a .nemo file
            # after all .nemo files are created, each rank
            # will add their checkpoint to global rank 0

            base_dir = os.path.dirname(save_path)  # use the directory to merge mp_rank .nemo files into one

            # update save_path based on model parallel_rank
            base_path = os.path.splitext(save_path)[0]  # everything except the extension

            mp_save_path = f'{base_path}_mp_rank_{app_state.model_parallel_rank:02d}.nemo'

            if app_state.data_parallel_rank == 0:
                super()._default_save_to(mp_save_path)

            # barrier so that all processes have finished writing their weights before creating .nemo file
            torch.distributed.barrier()

            if is_global_rank_zero():
                # extract all tar files
                for mp_rank in range(app_state.model_parallel_size):
                    mp_tar_path = f'{base_path}_mp_rank_{mp_rank:02d}.nemo'
                    mp_tar = tarfile.open(mp_tar_path, 'r:gz')
                    mp_tar.extractall(path=os.path.join(base_dir, f'mp_rank_{mp_rank:02d}'))
                    mp_tar.close()
                    os.remove(mp_tar_path)

                # move rank 0 .nemo extract to base_path
                shutil.move(os.path.join(base_dir, 'mp_rank_00'), base_path)

                # move mp_rank_00 checkpoint to mp_rank_00 directory inside base_path
                os.mkdir(os.path.join(base_path, 'mp_rank_00'))
                shutil.move(os.path.join(base_path, 'model_weights.ckpt'), os.path.join(base_path, 'mp_rank_00'))

                # move other mp_rank checkpoints from base_dir to base_path
                for mp_rank in range(1, app_state.model_parallel_size):
                    os.mkdir(os.path.join(base_path, f'mp_rank_{mp_rank:02d}'))
                    shutil.move(
                        os.path.join(base_dir, f'mp_rank_{mp_rank:02d}', 'model_weights.ckpt'),
                        os.path.join(base_path, f'mp_rank_{mp_rank:02d}'),
                    )
                    # clean up leftover directory
                    shutil.rmtree(os.path.join(base_dir, f'mp_rank_{mp_rank:02d}'))

                # create tar file from base_path
                self._make_nemo_file_from_folder(save_path, base_path)

                # clean up base_path
                shutil.rmtree(base_path)

        elif is_global_rank_zero():
            return super()._default_save_to(save_path)
        else:
            return
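After the merge step above, the model parallel .nemo archive ends up with roughly this layout (a sketch reconstructed from the move/mkdir calls above, not an authoritative spec of the format):

# model_parallel_model.nemo (tar.gz) contents:
#   model_config.yaml
#   mp_rank_00/model_weights.ckpt
#   mp_rank_01/model_weights.ckpt
#   ...one mp_rank_XX directory per tensor-model-parallel rank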
    @classmethod
    def restore_from(
        cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
        save_restore_connector: SaveRestoreConnector = None,
    ):
        """
        Restores model instance (weights and configuration) from .nemo file.

        Args:
            restore_path: path to .nemo file from which model should be instantiated
            override_config_path: path to a yaml config that will override the internal
                config file or an OmegaConf / DictConfig object representing the model config.
            map_location: Optional torch.device() to map the instantiated model to a device.
                By default (None), it will select a GPU if available, falling back to CPU otherwise.
            strict: Passed to load_state_dict. Set to True by default.
            return_config: If set to true, will return just the underlying config of the restored
                model as an OmegaConf DictConfig object without instantiating the model.
            trainer: PyTorch Lightning trainer. Must be passed in order to use model parallel .nemo

        Example:
            ```
            model = nemo.collections.nlp.models.TokenClassificationModel.restore_from('token_classification.nemo')
            assert isinstance(model, nemo.collections.nlp.models.TokenClassificationModel)
            ```

        Returns:
            An instance of type cls or its underlying config (if return_config is set).
        """
        if save_restore_connector is None:
            save_restore_connector = SaveRestoreConnector()

        if not os.path.exists(restore_path):
            raise FileNotFoundError(f"Can't find {restore_path}")

        app_state = AppState()
        app_state.model_restore_path = os.path.abspath(os.path.expanduser(restore_path))

        # detect if we have a model parallel .nemo file
        with tempfile.TemporaryDirectory() as tmpdir:
            cwd = os.getcwd()
            os.chdir(tmpdir)
            # detect if model parallel from tarfile
            tar = tarfile.open(app_state.model_restore_path, "r:gz")
            names = tar.getnames()
            mp_ranks = []
            for name in names:
                if 'mp_rank' in name:
                    mp_ranks.append(name)
            if mp_ranks:
                app_state.model_parallel_size = len(mp_ranks) // 2  # directory and file are included in getnames()

                # get checkpoint version
                checkpoint_version_member = None
                for member in tar.getmembers():
                    if 'megatron_checkpoint_version.json' in member.name:
                        checkpoint_version_member = member
                tar.extract(checkpoint_version_member, tmpdir)
                with open(checkpoint_version_member.name, 'r') as f:
                    checkpoint_version = json.load(f).get('checkpoint_version', None)
                logging.info(
                    (
                        f'Detected model parallel .nemo file: {restore_path}. '
                        f'Assuming megatron model parallelism with '
                        f'model_parallel_size: {app_state.model_parallel_size} '
                        f'and checkpoint version: {checkpoint_version}'
                    )
                )
            tar.close()
            os.chdir(cwd)

        if app_state.model_parallel_size is not None:
            if not isinstance(trainer, Trainer):
                raise ValueError("trainer must be a PyTorch Lightning Trainer to restore model parallel .nemo files.")

            if checkpoint_version is None:
                raise ValueError(
                    "Restoring from megatron model parallel .nemo but could not find megatron checkpoint version."
                )
            else:
                logging.info(f"Setting megatron checkpoint version: {checkpoint_version}")
                set_checkpoint_version(checkpoint_version)

            app_state.world_size = trainer.num_gpus * trainer.num_nodes

            if trainer.local_rank is not None:
                app_state.local_rank = trainer.local_rank
            else:
                raise ValueError("trainer.local_rank is None. local_rank needed to restore model parallel models.")

            model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)
            app_state.model_parallel_rank = model_parallel_rank

            cls.update_save_restore_connector(save_restore_connector)
            restored_model = cls._save_restore_connector.restore_from(
                cls, app_state.model_restore_path, override_config_path, map_location, strict, return_config
            )
            restored_model.set_trainer(trainer)
            return restored_model
        else:
            return super().restore_from(
                app_state.model_restore_path,
                override_config_path,
                map_location,
                strict,
                return_config,
                save_restore_connector=save_restore_connector,
            )

    @rank_zero_only
    def register_megatron_checkpoint_version(self):
        """ Adds checkpoint version to .nemo archive """
        if self.has_megatron_encoder:
-            checkpoint_version = get_checkpoint_version()
+            checkpoint_version = get_megatron_checkpoint_version()
            if checkpoint_version is None:
                raise ValueError('Unable to get megatron checkpoint version.')
            else:
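In restore_from above, the tensor-parallel degree is inferred from the tar member names, relying on each rank contributing exactly one directory and one checkpoint file to getnames(); a quick standalone check:

names = [
    'model_config.yaml',
    'mp_rank_00',
    'mp_rank_00/model_weights.ckpt',
    'mp_rank_01',
    'mp_rank_01/model_weights.ckpt',
]
mp_ranks = [name for name in names if 'mp_rank' in name]
model_parallel_size = len(mp_ranks) // 2  # directory and file both match
assert model_parallel_size == 2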
@ -511,3 +267,72 @@ class NLPModel(ModelPT, Exportable):
        if isinstance(self.encoder, MegatronEncoderModule):
            logging.info(f"Restoring from pretrained model parallel checkpoint: {self.encoder.checkpoint_file}")
            self.encoder._encoder.restore_weights(self.encoder.checkpoint_file)

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: str,
        map_location: Any = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        **kwargs,
    ):
        """
        Loads ModelPT from checkpoint, with some maintenance of restoration.
        For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
        """
        checkpoint = None
        try:
            cls._set_model_restore_state(is_being_restored=True)
            # TODO: replace with proper PTL API
            with pl_legacy_patch():
                if map_location is not None:
                    checkpoint = pl_load(checkpoint_path, map_location=map_location)
                else:
                    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

            if hparams_file is not None:
                extension = hparams_file.split(".")[-1]
                if extension.lower() == "csv":
                    hparams = load_hparams_from_tags_csv(hparams_file)
                elif extension.lower() in ("yml", "yaml"):
                    hparams = load_hparams_from_yaml(hparams_file)
                else:
                    raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

                hparams["on_gpu"] = False

                # overwrite hparams by the given file
                checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

            # for past checkpoint need to add the new key
            if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
                checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
            # override the hparams with values that were passed in
            # TODO: can we do this without overriding?
            config_kwargs = kwargs.copy()
            if 'trainer' in config_kwargs:
                config_kwargs.pop('trainer')
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs)

            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
            checkpoint = model

        finally:
            cls._set_model_restore_state(is_being_restored=False)
        return checkpoint

    def save_to(self, save_path: str):
        app_state = AppState()
        # Add NeMo rank check as well
        if app_state.model_parallel_size is not None:
            if app_state.model_parallel_size > 1:
                if not isinstance(self._save_restore_connector, NLPSaveRestoreConnector):
                    logging.warning(
                        f"Using {self._save_restore_connector.__class__} to save a model parallel model. Overriding with NLPSaveRestoreConnector. Make sure to subclass NLPSaveRestoreConnector."
                    )
                    self._save_restore_connector = NLPSaveRestoreConnector()
            save_path = os.path.abspath(os.path.expanduser(save_path))
            self._save_restore_connector.save_to(self, save_path)
        else:
            super(NLPModel, self).save_to(save_path=save_path)
@ -39,16 +39,6 @@ class BertModule(NeuralModule, Exportable):
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

-    @classmethod
-    def restore_from(cls, restore_path: str):
-        """Restores module/model with weights"""
-        pass
-
-    @classmethod
-    def save_to(self, save_path: str):
-        """Saves module/model with weights"""
-        pass
-
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights"""
        logging.info(f"Restoring weights from {restore_path}")
@ -89,13 +89,15 @@ def get_lm_model(
    )

    if "megatron" in pretrained_model_name:
-        model, checkpoint_file = get_megatron_lm_model(
-            config_dict=config_dict,
-            config_file=config_file,
-            pretrained_model_name=pretrained_model_name,
-            checkpoint_file=checkpoint_file,
-            vocab_file=vocab_file,
-        )
+        raise ValueError('megatron-lm BERT models have been deprecated in NeMo 1.5+. Please use NeMo 1.4 for support.')
+        # TODO: enable megatron bert in nemo
+        # model, checkpoint_file = get_megatron_lm_model(
+        #     config_dict=config_dict,
+        #     config_file=config_file,
+        #     pretrained_model_name=pretrained_model_name,
+        #     checkpoint_file=checkpoint_file,
+        #     vocab_file=vocab_file,
+        # )
    else:
        model = get_huggingface_lm_model(
            config_dict=config_dict, config_file=config_file, pretrained_model_name=pretrained_model_name,

@ -180,13 +182,17 @@ def get_transformer(
    )

    elif library == 'megatron':
-        model = get_megatron_transformer(
-            model_name=model_name,
-            pretrained=pretrained,
-            config_dict=config_dict,
-            encoder=encoder,
-            checkpoint_file=checkpoint_file,
-        )
+        raise ValueError(
+            f'megatron-lm bert support has been deprecated in NeMo 1.5+. Please use NeMo 1.4 for support.'
+        )
+        # TODO: enable megatron bert in nemo
+        # model = get_megatron_transformer(
+        #     model_name=model_name,
+        #     pretrained=pretrained,
+        #     config_dict=config_dict,
+        #     encoder=encoder,
+        #     checkpoint_file=checkpoint_file,
+        # )

    else:
        raise ValueError("Library must be 'nemo', 'huggingface' or 'megatron'")
nemo/collections/nlp/modules/common/megatron/clip_grads.py (new file, 137 lines)

@ -0,0 +1,137 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gradient clipping."""

import amp_C
import torch
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
from torch._six import inf

from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared


def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients
    are in fp32.
    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.
    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        grad = param.grad.detach()
        if grad_not_none:
            # Make sure the grads are in fp32
            assert isinstance(param.grad, torch.cuda.FloatTensor)
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
        )
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            grad_norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False  # no per-parameter norm
            )
            # Since we will be summing across data parallel groups,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type

        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
        )
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)

    return total_norm
|
||||
def count_zeros_fp32(parameters):
|
||||
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
total_num_zeros = 0.0
|
||||
for param in parameters:
|
||||
grad_not_none = param.grad is not None
|
||||
is_not_shared = param_is_not_shared(param)
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
if grad_not_none and is_not_shared and is_not_tp_duplicate:
|
||||
grad = param.grad.detach()
|
||||
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
||||
total_num_zeros = num_zeros + total_num_zeros
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(
|
||||
total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
|
||||
)
|
||||
total_num_zeros = total_num_zeros.item()
|
||||
|
||||
return total_num_zeros
|
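A minimal usage sketch for the new helper (not part of the diff): it assumes a CUDA device, an initialized torch.distributed process group, and apex's parallel_state already set up (e.g., by megatron_init); the toy Linear module stands in for a real Megatron module.

import torch

from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32

model = torch.nn.Linear(16, 16).cuda()  # toy stand-in for a Megatron module
loss = model(torch.randn(4, 16, device='cuda')).sum()
loss.backward()

# Clip to a max L2 norm of 1.0; the pre-clipping total norm is returned.
total_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=2)
print(f'grad norm before clipping: {total_norm:.4f}')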
nemo/collections/nlp/modules/common/megatron/enums.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
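A small illustrative sketch (not in the diff) of how these switches are typically paired in the Megatron modules: decoder layers use causal masking, encoder layers use padding masks.

from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType, LayerType

layer_type = LayerType.decoder
attn_mask_type = AttnMaskType.causal if layer_type == LayerType.decoder else AttnMaskType.padding
print(attn_mask_type)  # AttnMaskType.causal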
@@ -0,0 +1,62 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


@torch.jit.script
def bias_dropout_add_fused_train_(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


class BiasDropoutAddFusedTrain(AutocastModuleWrapper):
    def __init__(self, fp16=False, bf16=False):
        super(BiasDropoutAddFusedTrain, self).__init__(fp16, bf16)

        self.func = bias_dropout_add_fused_train_

    def forward(self, x, bias, residual, prob):
        return self.autocast_forward(x, bias, residual, prob)


@torch.jit.script
def bias_dropout_add_fused_inference_(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)


class BiasDropoutAddFusedInference(AutocastModuleWrapper):
    def __init__(self, fp16=False, bf16=False):
        super(BiasDropoutAddFusedInference, self).__init__(fp16, bf16)

        self.func = bias_dropout_add_fused_inference_

    def forward(self, x, bias, residual, prob):
        return self.autocast_forward(x, bias, residual, prob)
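A quick sanity sketch (not in the diff): with dropout probability 0 the fused TorchScript path must match the eager composition residual + dropout(x + bias) exactly.

import torch

x, bias, residual = torch.randn(3, 2, 8).unbind(0)
eager = residual + torch.nn.functional.dropout(x + bias, p=0.0, training=False)
fused = bias_dropout_add_fused_inference_(x, bias, residual, 0.0)
assert torch.allclose(eager, fused)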
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper

###### BIAS GELU FUSION / NO AUTOGRAD ################
# 1/sqrt(2*pi) -> 0.3989423
# 1/sqrt(2)    -> 0.70710678
# sqrt(2/pi)   -> 0.79788456
# This function is the tanh approximation of gelu; the exact gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# Gradient of the tanh approximation of gelu; the gradient of exact gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


class FusedBiasGeLU(AutocastModuleWrapper):
    def __init__(self, fp16=False, bf16=False):
        super(FusedBiasGeLU, self).__init__(fp16, bf16)

        self.func = GeLUFunction

    def forward(self, input, bias):
        return self.autocast_forward(input, bias)
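Sanity sketch (not in the diff): compare the tanh approximation against the exact erf-based GELU on a small range; the absolute error should stay small (roughly 1e-3 or below).

import torch

x = torch.linspace(-4, 4, steps=101)
bias = torch.zeros_like(x)
approx = bias_gelu(bias, x)
exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
print((approx - exact).abs().max())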
@@ -0,0 +1,107 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below.
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():

    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=['-O3'],
            extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math']
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    # ==============
    # Fused softmax.
    # ==============

    extra_cuda_flags = [
        '-U__CUDA_NO_HALF_OPERATORS__',
        '-U__CUDA_NO_HALF_CONVERSIONS__',
        '--expt-relaxed-constexpr',
        '--expt-extended-lambda',
    ]

    # Upper triangular softmax.
    sources = [
        srcpath / 'scaled_upper_triang_masked_softmax.cpp',
        srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu',
    ]
    scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
        "scaled_upper_triang_masked_softmax_cuda", sources, extra_cuda_flags
    )

    # Masked softmax.
    sources = [srcpath / 'scaled_masked_softmax.cpp', srcpath / 'scaled_masked_softmax_cuda.cu']
    scaled_masked_softmax_cuda = _cpp_extention_load_helper("scaled_masked_softmax_cuda", sources, extra_cuda_flags)

    # =================================
    # Mixed precision fused layer norm.
    # =================================

    extra_cuda_flags = ['-maxrregcount=50']
    sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu']
    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags
    )


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
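Usage sketch (not in the diff; the import path of the loader module is an assumption): load() JIT-compiles the three extensions through torch's cpp_extension machinery, so the first call pays the nvcc build cost and later calls hit the build cache. It requires CUDA_HOME/nvcc to be available.

# Assumed import path for the loader module shown above.
from nemo.collections.nlp.modules.common.megatron.fused_kernels import load

load()  # builds the scaled softmax kernels and the fused layer norm extension

import scaled_masked_softmax_cuda  # importable once the build succeeds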
@@ -0,0 +1,31 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 * with minor changes. */


#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
@@ -0,0 +1,201 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 * with minor changes. */

#include <torch/extension.h>
#include <vector>
#include <cassert>
#include <sstream>
#include "compat.h"

namespace {

void compute_n1_n2(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta)
{
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2)
{
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2)
{
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}
}

void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output = at::empty_like(
      input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty(
      {n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon,
      &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine,
        "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
        "LayerNorm backward (CUDA)");
}
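Usage sketch (not in the diff): calling the bindings above directly once fused_mix_prec_layer_norm_cuda has been built by the loader. forward_affine returns the normalized output plus the per-row mean and inverse standard deviation that backward_affine later consumes.

import torch
import fused_mix_prec_layer_norm_cuda

x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
gamma = torch.ones(1024, device='cuda', dtype=torch.float16)
beta = torch.zeros(1024, device='cuda', dtype=torch.float16)
out, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(x, [1024], gamma, beta, 1e-5)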
@@ -0,0 +1,829 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 * with minor changes. */

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"

#include <cuda.h>
#include <cuda_runtime.h>

#include "type_shim.h"

template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}

template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA*mu + nB*muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}
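Plain-Python sketch (not in the diff) of the two update rules above: Welford's online mean/variance for a single stream, and Chan's merge formula that cuChanOnlineSum applies when combining partial results across lanes and warps.

def welford_update(mu, m2, count, x):
    # One Welford step: fold a new sample into (mean, sum of squared deviations).
    count += 1
    delta = x - mu
    mu += delta / count
    m2 += delta * (x - mu)
    return mu, m2, count

def chan_merge(mu_a, m2_a, n_a, mu_b, m2_b, n_b):
    # Chan's formula: merge two partial (mean, M2, count) triples.
    n = n_a + n_b
    if n == 0:
        return 0.0, 0.0, 0
    delta = mu_b - mu_a
    mu = (n_a * mu_a + n_b * mu_b) / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return mu, m2, n

data = [1.0, 2.0, 4.0, 8.0]
mu_a = m2_a = 0.0; n_a = 0
for v in data[:2]:
    mu_a, m2_a, n_a = welford_update(mu_a, m2_a, n_a, v)
mu_b = m2_b = 0.0; n_b = 0
for v in data[2:]:
    mu_b, m2_b, n_b = welford_update(mu_b, m2_b, n_b, v)
mu, m2, n = chan_merge(mu_a, m2_a, n_a, mu_b, m2_b, n_b)
assert abs(mu - 3.75) < 1e-12 and abs(m2 - 28.75) < 1e-12  # m2 / n is the population variance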
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    int l = 4*thrx;
    for (; l+3 < n2; l+=4*numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}

template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l+7 < n2; l+=8*numx) {
      for (int k = 0; k < 8; k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}

template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
template<> double rsqrt(double v) {
  return rsqrt(v);  // resolves to CUDA's builtin double-precision rsqrt, not this specialization
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
//     // Ensure that we won't compile any un-specialized types
//     __device__ T *getPointer()
//     {
//         extern __device__ void error(void);
//         error();
//         return NULL;
//     }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
  __device__ float *getPointer()
  {
    extern __shared__ float s_float[];
    return s_float;
  }
};

}

template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
  }
}
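Reference sketch (not in the diff) of what cuApplyLayerNorm computes per row, in plain PyTorch; the kernel additionally stashes mean and invvar for the backward pass.

import torch

x = torch.randn(4, 1024)
gamma, beta, eps = torch.ones(1024), torch.zeros(1024), 1e-5
mu = x.mean(dim=1, keepdim=True)
sigma2 = x.var(dim=1, unbiased=False, keepdim=True)  # the kernel's sigma2 / n2
y = gamma * (x - mu) * torch.rsqrt(sigma2 + eps) + beta
assert torch.allclose(y, torch.nn.functional.layer_norm(x, (1024,), gamma, beta, eps), atol=1e-5)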
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
  const int i1_block,
  const int thr_load_row_off,
  const int thr_load_col_off,
  const int i2_off,
  const int row_stride,
  U* warp_buf1,
  U* warp_buf2,
  const T* input,
  const V* dout,
  const int i1_end,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar
  )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
  const int i1_block,
  const int thr_load_row_off,
  const int thr_load_col_off,
  const int i2_off,
  const int row_stride,
  U* warp_buf1,
  U* warp_buf2,
  const T* input,
  const V* dout,
  const int i1_end,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar
  )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
  const V* __restrict__ dout,
  const T* __restrict__ input,
  const int n1,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar,
  U epsilon,
  U* part_grad_gamma,
  U* part_grad_beta)
{
  const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x+1;
  const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
  const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
    cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k*blockDim.y;
    int idx1 = row1*row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
  warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1*row_stride + threadIdx.x;
    int idx2 = row2*row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
  const U* part_grad_gamma,
  const U* part_grad_beta,
  const int part_size,
  const int n1,
  const int n2,
  V* grad_gamma,
  V* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
      sum_beta += part_grad_beta_ptr[warp_offset*n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx+nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx+nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
  const V* __restrict__ dout,
  const T* __restrict__ input,
  const int n1,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar,
  U epsilon,
  const V* gamma,
  T* grad_input)
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
  V* output,
  U* mean,
  U* invvar,
  const T* input,
  int n1,
  int n2,
  double epsilon,
  const V* gamma,
  const V* beta
  )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32,4,1);
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ?
          threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
          0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output,
      mean,
      invvar,
      input,
      n1,n2,
      U(epsilon),
      gamma,beta);
}

void cuda_layer_norm(
  at::Tensor* output,
  at::Tensor* mean,
  at::Tensor* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  #ifdef VERSION_GE_1_1
  at::IntArrayRef normalized_shape,
  #else
  at::IntList normalized_shape,
  #endif
  at::Tensor* gamma,
  at::Tensor* beta,
  double epsilon)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
      HostApplyLayerNorm(
          output->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input->DATA_PTR<scalar_t_in>(),
          n1,n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}

template<typename T, typename U, typename V>
void HostLayerNormGradient(
  const V* dout,
  const U* mean,
  const U* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  const V* gamma,
  const V* beta,
  double epsilon,
  T* grad_input,
  V* grad_gamma,
  V* grad_beta
  )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32,4,1);
    const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
        (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
        {part_size,n2}, input->options().dtype(at::ScalarType::Float));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1,n2,
        mean,
        invvar,
        U(epsilon),
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>());

    const dim3 threads3(32,8,1);
    const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1,n2,
        grad_gamma,
        grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32,4,1);
  int nshared =
      threads1.y > 1 ?
          threads1.y*threads1.x*sizeof(U) :
          0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1,n2,
      mean,
      invvar,
      U(epsilon),
      gamma,
      grad_input);
}

void cuda_layer_norm_gradient(
  at::Tensor* dout,
  at::Tensor* mean,
  at::Tensor* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  #ifdef VERSION_GE_1_1
  at::IntArrayRef normalized_shape,
  #else
  at::IntList normalized_shape,
  #endif
  at::Tensor* gamma,
  at::Tensor* beta,
  double epsilon,
  at::Tensor* grad_input,
  at::Tensor* grad_gamma,
  at::Tensor* grad_beta)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma->scalar_type(),
      "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input,
          n1,n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}
@@ -0,0 +1,97 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {

  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");

  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
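Usage sketch (not in the diff; the broadcastable mask layout is an assumption based on how the kernels index it): forward takes 4D fp16/bf16 attention scores, a uint8 mask where 1 means "masked out", and a scale applied before the softmax.

import torch
import scaled_masked_softmax_cuda

scores = torch.randn(2, 8, 128, 128, device='cuda', dtype=torch.float16)
mask = torch.zeros(2, 1, 128, 128, device='cuda', dtype=torch.uint8)  # nothing masked
probs = scaled_masked_softmax_cuda.forward(scores, mask, 1.0)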
@@ -0,0 +1,505 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
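Sketch (not in the diff) of the butterfly pattern warp_reduce implements: each step XORs the lane id with a halving offset, so after log2(WARP_SIZE) steps every lane holds the full reduction. A numpy simulation of the shuffle pattern:

import numpy as np

vals = np.arange(32, dtype=np.float64)  # one value per lane
offset = 16
while offset > 0:
    vals = vals + vals[np.arange(32) ^ offset]  # analogue of __shfl_xor_sync
    offset //= 2
assert np.allclose(vals, np.arange(32, dtype=np.float64).sum())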
/*
|
||||
* Extended softmax (from native aten pytorch) with following additional features
|
||||
* 1) input scaling
|
||||
* 2) Explicit masking
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const acc_t scale,
|
||||
int micro_batch_size,
|
||||
int element_count,
|
||||
int pad_batches)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
|
||||
int pad_first_batch = 0;
|
||||
if (pad_batches != 1) { // bert style
|
||||
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
|
||||
} else { // gpt2 style
|
||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
}
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
int itr_idx = i*element_count+it*WARP_SIZE;
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (temp_mask[element] != 1) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -10000.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH] { 0.0f };
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
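
For reference (an editorial sketch, not part of the diff; the helper name is hypothetical), the forward kernel above computes the same result as this eager PyTorch snippet. Note that positions where mask == 1 are filled with -10000.0 rather than -infinity, so a fully masked row still yields finite, non-NaN probabilities:

import torch

def scaled_masked_softmax_reference(inputs, mask, scale):
    # inputs: [batches, attn_heads, query_seq_len, key_seq_len]
    # mask: broadcastable to inputs; entries equal to 1 are masked out
    scaled = inputs * scale
    scaled = scaled.masked_fill(mask.to(torch.bool), -10000.0)
    return torch.softmax(scaled, dim=-1)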

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch.
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
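
The value stored above is the standard softmax backward with the forward scale folded in: dx = scale * y * (dy - sum_k(dy_k * y_k)). A minimal eager equivalent (editorial sketch, hypothetical name):

import torch

def scaled_masked_softmax_backward_reference(output_grads, softmax_results, scale):
    # dx = scale * (y * dy - y * sum(y * dy)), reduced over the softmax axis
    prod = output_grads * softmax_results
    return scale * (prod - softmax_results * prod.sum(dim=-1, keepdim=True))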

} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}
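
Only key_seq_len actually influences the result; the other parameters are kept for signature parity with the dispatch functions below. A pure-Python mirror of the arithmetic, with two worked examples (editorial sketch):

C10_WARP_SIZE = 32  # warp size on current NVIDIA GPUs

def batches_per_block(key_seq_len):
    log2_elements = (key_seq_len - 1).bit_length() if key_seq_len > 1 else 0  # log2_ceil
    next_power_of_two = 1 << log2_elements
    warp_size = min(next_power_of_two, C10_WARP_SIZE)
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    return (128 // warp_size) * batches_per_warp  # 128 threads per block

assert batches_per_block(512) == 4  # 4 warps/block x 1 batch/warp
assert batches_per_block(128) == 8  # 4 warps/block x 2 batches/warp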

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            default:
                break;
        }
    }
}

@@ -0,0 +1,117 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
    // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
    const int batches = input.size(0);
    const int pad_batches = mask.size(0);
    const int attn_heads = input.size(1);
    const int query_seq_len = input.size(2);
    const int key_seq_len = input.size(3);
    TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
    TORCH_INTERNAL_ASSERT(query_seq_len > 1);
    TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
    TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
    TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
    TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

    // Output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

    // Softmax Intermediate Result Ptr
    void* input_ptr = static_cast<void*>(input.data_ptr());
    void* mask_ptr = static_cast<void*>(mask.data_ptr());
    void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_masked_softmax_forward",
        dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(softmax_results_ptr),
            reinterpret_cast<const scalar_t*>(input_ptr),
            reinterpret_cast<const uint8_t*>(mask_ptr),
            scale_factor,
            query_seq_len,
            key_seq_len,
            batches,
            attn_heads,
            pad_batches);
    );
    return softmax_results;
}
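
Assuming this translation unit is built into the scaled_masked_softmax_cuda extension (the module name the autograd wrapper later in this PR imports), fwd_cuda can be smoke-tested from Python roughly like this (editorial sketch; requires a CUDA device and the compiled extension):

import torch
import scaled_masked_softmax_cuda  # compiled extension exposing this fwd_cuda

b, heads, sq, sk = 4, 8, 128, 128
inputs = torch.randn(b, heads, sq, sk, dtype=torch.float16, device='cuda')
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device='cuda')  # 1 = masked out
probs = scaled_masked_softmax_cuda.forward(inputs, mask, 1.0)
row_sums = probs.sum(-1).float()  # each row should be a proper distribution
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-2)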

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {

    auto output_grads = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
    const int batches = output_grads.size(0);
    const int attn_heads = output_grads.size(1);
    const int query_seq_len = output_grads.size(2);
    const int key_seq_len = output_grads.size(3);

    void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

    // Softmax Grad
    DISPATCH_HALF_AND_BFLOAT(
        output_grads_.scalar_type(),
        "dispatch_scaled_masked_softmax_backward",
        dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
            scale_factor,
            query_seq_len,
            key_seq_len,
            batches,
            attn_heads);
    );

    // backward pass is completely in-place
    return output_grads;
}
}
}
}

@@ -0,0 +1,72 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
    AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                   (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {

    AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
    AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                   (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                   (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward",
          &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
          "Self Multihead Attention scaled, time masked softmax -- Forward.");
    m.def("backward",
          &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
          "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
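
Once compiled, the bound functions are reachable from Python; a minimal sketch assuming the extension is built under the name scaled_upper_triang_masked_softmax_cuda (the name the Python wrapper later in this PR imports; requires a CUDA device):

import torch
import scaled_upper_triang_masked_softmax_cuda

attn_batches, seq_len = 32, 128
inputs = torch.randn(attn_batches, seq_len, seq_len, dtype=torch.float16, device='cuda')
probs = scaled_upper_triang_masked_softmax_cuda.forward(inputs, 1.0)
# row i only attends to keys 0..i; everything above the diagonal is exactly zero
assert torch.all(probs.triu(diagonal=1) == 0)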

@@ -0,0 +1,513 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }


int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with the following additional features:
 * 1) input scaling
 * 2) implicit time (diagonal) masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch.
    int local_idx = threadIdx.x;

    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if ((element_index + element) < batch_element_count) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < local_seq) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < local_seq) {
                        out[element] = elements[i][it + element] / sum[i];
                    } else {
                        out[element] = 0;
                    }
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
            } else if (element_index < element_count) {
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
            } else {
                break;
            }
        }
    }
}
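
Semantically, this kernel applies a causal (upper-triangular) mask implicitly: row i of each seq_len x seq_len matrix attends only to keys 0..i, and entries above the diagonal are written as exact zeros. An eager reference (editorial sketch, hypothetical name):

import torch

def scaled_upper_triang_masked_softmax_reference(inputs, scale):
    # inputs: [attn_batches, seq_len, seq_len]
    seq_len = inputs.size(-1)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=inputs.device).triu(diagonal=1)
    scaled = (inputs * scale).masked_fill(causal, float('-inf'))
    return torch.softmax(scaled, dim=-1)  # softmax already maps the -inf entries to 0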

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch.
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        output_reg[i][it + element] = (acc_t)temp_output[element];
                    }
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                    }
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

@@ -0,0 +1,98 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
    // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
    const int attn_batches = input.size(0);
    const int seq_len = input.size(1);
    TORCH_INTERNAL_ASSERT(seq_len <= 2048);

    // Output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({attn_batches, seq_len, seq_len}, act_options);

    // Softmax Intermediate Result Ptr
    void* input_ptr = static_cast<void*>(input.data_ptr());
    void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_upper_triang_masked_softmax_forward",
        dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(softmax_results_ptr),
            reinterpret_cast<const scalar_t*>(input_ptr),
            scale_factor,
            seq_len,
            seq_len,
            attn_batches);
    );
    return softmax_results;
}


torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {

    auto output_grads = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
    const int attn_batches = output_grads.size(0);
    const int seq_len = output_grads.size(1);
    TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

    void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

    // Softmax Grad
    DISPATCH_HALF_AND_BFLOAT(
        output_grads_.scalar_type(),
        "dispatch_scaled_upper_triang_masked_softmax_backward",
        dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
            scale_factor,
            seq_len,
            seq_len,
            attn_batches);
    );

    // backward pass is completely in-place
    return output_grads;
}
}
}
}

@@ -0,0 +1,91 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


#include <ATen/ATen.h>
#include "compat.h"


#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
    switch (TYPE) \
    { \
        case at::ScalarType::Half: \
        { \
            using scalar_t = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }


#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
    switch (TYPEIN) \
    { \
        case at::ScalarType::Float: \
        { \
            using scalar_t_in = float; \
            switch (TYPEOUT) \
            { \
                case at::ScalarType::Float: \
                { \
                    using scalar_t_out = float; \
                    __VA_ARGS__; \
                    break; \
                } \
                case at::ScalarType::Half: \
                { \
                    using scalar_t_out = at::Half; \
                    __VA_ARGS__; \
                    break; \
                } \
                case at::ScalarType::BFloat16: \
                { \
                    using scalar_t_out = at::BFloat16; \
                    __VA_ARGS__; \
                    break; \
                } \
                default: \
                    AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
            } \
            break; \
        } \
        case at::ScalarType::Half: \
        { \
            using scalar_t_in = at::Half; \
            using scalar_t_out = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t_in = at::BFloat16; \
            using scalar_t_out = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
    }

211 nemo/collections/nlp/modules/common/megatron/fused_softmax.py Normal file
@@ -0,0 +1,211 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

from nemo.collections.nlp.modules.common.megatron.enums import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.utils import AutocastModuleWrapper


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None
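
A quick smoke test of the autograd wiring (editorial sketch; needs the compiled extension and a CUDA device; fp64 gradcheck does not apply here since only fp16/bf16 inputs are supported):

import torch

x = torch.randn(8, 64, 64, dtype=torch.float16, device='cuda', requires_grad=True)
y = ScaledUpperTriangMaskedSoftmax.apply(x, 1.0)  # autograd.Function entry point
y.sum().backward()  # routes through scaled_upper_triang_masked_softmax_cuda.backward
assert x.grad.shape == x.shape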

class FusedScaledUpperTriangMaskedSoftmax(AutocastModuleWrapper):
    def __init__(self, fp16=False, bf16=True):
        super(FusedScaledUpperTriangMaskedSoftmax, self).__init__(fp16, bf16)

        self.func = ScaledUpperTriangMaskedSoftmax

    def forward(self, inputs, scale):
        return self.autocast_forward(inputs, scale)


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class FusedScaledMaskedSoftmax(AutocastModuleWrapper):
    def __init__(self, fp16=False, bf16=True):
        super(FusedScaledMaskedSoftmax, self).__init__(fp16, bf16)

        self.fused_func = ScaledMaskedSoftmax

    def forward(self, inputs, mask, scale):
        return self.autocast_forward(inputs, mask, scale)


class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

        self.fused_scaled_upper_triang_masked_softmax = FusedScaledUpperTriangMaskedSoftmax(
            input_in_fp16, input_in_bf16
        )
        # Note: this was originally `ScaledMaskedSoftmax(input_in_fp16, input_in_bf16)`, which
        # constructs the raw autograd.Function instead of its autocast wrapper; the wrapper
        # class is what takes these constructor arguments.
        self.fused_scaled_masked_softmax = FusedScaledMaskedSoftmax(input_in_fp16, input_in_bf16)

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False
|
||||
|
||||
def forward_fused_softmax(self, input, mask):
|
||||
b, np, sq, sk = input.size()
|
||||
scale = self.scale if self.scale is not None else 1.0
|
||||
|
||||
if self.attn_mask_type == AttnMaskType.causal:
|
||||
assert sq == sk, "causal mask is only for self attention"
|
||||
|
||||
# input is 3D tensor (attn_batches, sq, sk)
|
||||
input = input.view(-1, sq, sk)
|
||||
probs = self.fused_scaled_uppper_triang_masked_softmax(input, scale)
|
||||
return probs.view(b, np, sq, sk)
|
||||
else:
|
||||
# input is 4D tensor (b, np, sq, sk)
|
||||
return self.fused_scaled_masked_softmax(input, mask, scale)
|
||||
|
||||
def forward_torch_softmax(self, input, mask):
|
||||
if self.input_in_float16 and self.softmax_in_fp32:
|
||||
input = input.float()
|
||||
|
||||
if self.scale is not None:
|
||||
input = input * self.scale
|
||||
mask_output = self.mask_func(input, mask) if mask is not None else input
|
||||
probs = torch.nn.Softmax(dim=-1)(mask_output)
|
||||
|
||||
if self.input_in_float16 and self.softmax_in_fp32:
|
||||
if self.input_in_fp16:
|
||||
probs = probs.half()
|
||||
else:
|
||||
probs = probs.bfloat16()
|
||||
|
||||
return probs
|
||||
|
||||
@staticmethod
|
||||
def get_batch_per_block(sq, sk, b, np):
|
||||
import scaled_masked_softmax_cuda
|
||||
|
||||
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
|
nemo/collections/nlp/modules/common/megatron/language_model.py (new file, 581 lines)
@@ -0,0 +1,581 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType, LayerType

from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
from nemo.collections.nlp.modules.common.megatron.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(
    hidden_size,
    ffn_hidden_size,
    num_layers,
    max_position_embeddings,
    num_tokentypes,
    add_pooler,
    vocab_size,
    num_attention_heads,
    encoder_attn_mask_type,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    fused_fp16=False,
    fused_bf16=False,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    openai_gelu=False,
    onnx_safe=False,
):
    """Build language model and return along with the key to save."""

    assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        encoder_attn_mask_type=encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        vocab_size=vocab_size,
        max_position_embeddings=max_position_embeddings,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        fused_fp16=fused_fp16,
        fused_bf16=fused_bf16,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
            is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        init_method,
        num_tokentypes=0,
        use_cpu_initialization=False,
    ):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size, init_method=self.init_method, use_cpu_initialization=use_cpu_initialization,
        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrained model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
            destination, prefix, keep_vars
        )
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                destination, prefix, keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print(
                    '***WARNING*** expected tokentype embeddings in the checkpoint but could not find it',
                    flush=True,
                )
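
The backward-compatibility branch in `Embedding.load_state_dict` strips flat Megatron-LM checkpoint keys down to sub-module-local names before loading. A small runnable sketch of that key-stripping (the flat key names and the string stand-ins for tensors are illustrative):

```python
# Flat legacy keys such as 'language_model.embedding.word_embeddings.weight'
# are cut down to the part after 'word_embeddings.' ('weight' here).
flat_state_dict = {
    'language_model.embedding.word_embeddings.weight': 'tensor...',
    'language_model.embedding.position_embeddings.weight': 'tensor...',
}

word_embeddings_state = {
    key.split('word_embeddings.')[1]: value
    for key, value in flat_state_dict.items()
    if 'word_embeddings' in key
}
print(word_embeddings_state)  # {'weight': 'tensor...'}
```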

class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
            is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        encoder_attn_mask_type,
        vocab_size,
        max_position_embeddings,
        hidden_size,
        ffn_hidden_size,
        num_layers,
        num_tokentypes,
        num_attention_heads,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        add_decoder=False,
        decoder_attn_mask_type=AttnMaskType.causal,
        add_pooler=False,
        pre_process=True,
        post_process=True,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        fused_fp16=False,
        fused_bf16=False,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
    ):
        super(TransformerLanguageModel, self).__init__()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.hidden_dropout = hidden_dropout
        self.output_layer_init_method = output_layer_init_method

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(
                hidden_size=self.hidden_size,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_position_embeddings,
                init_method=self.init_method,
                num_tokentypes=self.num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=self.hidden_dropout,
            )
            self._embedding_key = 'embedding'

        # Transformer.
        self.encoder = ParallelTransformer(
            init_method=self.init_method,
            output_layer_init_method=self.output_layer_init_method,
            num_layers=self.num_layers,
            hidden_size=self.hidden_size,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            self_attn_mask_type=self.encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process,
            fused_fp16=fused_fp16,
            fused_bf16=fused_bf16,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            hidden_dropout=hidden_dropout,
            use_cpu_initialization=use_cpu_initialization,
            bias_gelu_fusion=bias_gelu_fusion,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )
        self._encoder_key = 'encoder'

        # Decoder
        if self.add_decoder:
            assert (
                parallel_state.get_pipeline_model_parallel_world_size() == 1
            ), 'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                init_method=self.init_method,
                output_layer_init_method=self.output_layer_init_method,
                num_layers=self.num_layers,
                hidden_size=self.hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                ffn_hidden_size=ffn_hidden_size,
                pre_process=self.pre_process,
                post_process=self.post_process,
                fused_fp16=fused_fp16,
                fused_bf16=fused_bf16,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )
            self._decoder_key = 'decoder'

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)

    def forward(
        self,
        enc_input_ids,
        enc_position_ids,
        enc_attn_mask,
        dec_input_ids=None,
        dec_position_ids=None,
        dec_attn_mask=None,
        enc_dec_attn_mask=None,
        tokentype_ids=None,
        layer_past=None,
        get_key_value=False,
        pooling_sequence_index=0,
        enc_hidden_states=None,
        output_enc_hidden=False,
    ):
        # Embeddings.
        if self.pre_process:
            embedding_output = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
            encoder_input = embedding_output
        else:
            encoder_input = None

        # encoder.
        if enc_hidden_states is None:
            encoder_output = self.encoder(
                encoder_input, enc_attn_mask, layer_past=layer_past, get_key_value=get_key_value
            )
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output, pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder Embedding
        dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids)
        # decoder
        decoder_output = self.decoder(
            dec_embedding_output,
            dec_attn_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )
        state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars
                )
        if self.add_decoder:
            state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self._encoder_key in state_dict:
            state_dict_ = state_dict[self._encoder_key]
        # for backward compatibility.
        elif 'transformer' in state_dict:
            state_dict_ = state_dict['transformer']
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]

        # for backward compatibility.
        state_dict_self_attention = {}
        for key in state_dict_.keys():
            if '.attention.' in key:
                state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
            else:
                state_dict_self_attention[key] = state_dict_[key]
        state_dict_ = state_dict_self_attention

        self.encoder.load_state_dict(state_dict_, strict=strict)

        if self.post_process:
            # pooler
            if self.add_pooler:
                assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
        # decoder
        if self.add_decoder:
            assert 'decoder' in state_dict, 'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
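
The `Pooler` computation above is easy to reproduce without apex; a minimal plain-torch sketch, with `torch.nn.Linear` standing in for `get_linear_layer`:

```python
import torch

hidden_states = torch.randn(2, 8, 16)               # [b, s, h]
dense = torch.nn.Linear(16, 16)                     # stand-in for get_linear_layer
pooled = torch.tanh(dense(hidden_states[:, 0, :]))  # pool the token at sequence_index=0
print(pooled.shape)                                 # torch.Size([2, 16])
```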

nemo/collections/nlp/modules/common/megatron/megatron_bert.py (modified)
@@ -17,11 +17,8 @@
 import os
 
 import torch
-from megatron import get_args, initialize_megatron
-from megatron.checkpointing import set_checkpoint_version
-from megatron.model import get_language_model
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
-from megatron.mpu import (
+from apex.transformer.enums import AttnMaskType
+from apex.transformer.parallel_state import (
     get_model_parallel_group,
     model_parallel_is_initialized,
     set_pipeline_model_parallel_rank,
@@ -30,6 +27,7 @@ from megatron.mpu import (
 from omegaconf import DictConfig, OmegaConf
 
 from nemo.collections.nlp.modules.common.bert_module import BertModule
+from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model
 from nemo.core.classes import typecheck
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
@@ -37,13 +35,6 @@ from nemo.utils.app_state import AppState
 __all__ = ['MegatronBertEncoder']
 
 
-def complete_lazy_init(self):
-    # finish megatron-lm initialization
-    if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None:
-        self._lazy_init_fn()
-        self._lazy_init_fn = None
-
-
 class MegatronBertEncoder(BertModule):
     """
     MegatronBERT wraps around the Megatron Language model
@@ -59,96 +50,47 @@ class MegatronBertEncoder(BertModule):
 
         super().__init__()
 
-        self._model_parallel_size = model_parallel_size
-        self._model_parallel_rank = model_parallel_rank
-        self._restore_path = None
-        self._app_state = None
-        self._model_name = model_name
-
-        if 'vocab_size' in config:
-            self._vocab_size = config.pop('vocab_size')
-        else:
-            self._vocab_size = None
-
-        self._hidden_size = config.get('hidden_size')
-
-        if not os.path.exists(vocab_file):
-            raise ValueError(f'Vocab file not found at {vocab_file}')
-
-        # convert config to dictionary
-        if isinstance(config, DictConfig):
-            config = OmegaConf.to_container(config)
-        config["vocab_file"] = vocab_file
-        config['tokenizer_type'] = 'BertWordPieceLowerCase'
-        config['lazy_mpu_init'] = True
-        config['onnx_safe'] = True
-
-        num_tokentypes = config.pop('num_tokentypes', 2)
-
-        # if 'model_parallel_size' in config:
-        if self._model_parallel_size is not None:
-            app_state = AppState()
-            self._app_state = app_state
-
-            # must be set for model parallel megatron-lm
-            os.environ["WORLD_SIZE"] = str(app_state.world_size)
-            os.environ["RANK"] = str(self._model_parallel_rank)
-
-            extra_args_provider = self._update_megatron_args(tensor_model_parallel_size=self._model_parallel_size)
-
-        else:
-            extra_args_provider = self._update_megatron_args()
-
-        # configure globals for megatron
-        set_pipeline_model_parallel_rank(0)  # pipeline model parallelism not implemented in NeMo
-        set_pipeline_model_parallel_world_size(1)  # pipeline model parallelism not implemented in NeMo
-
-        # Initialize part of Megatron global state that is needed for its constructor.
-        # We set 'lazy_mpu_init' flag on to make Megatron do only the initialization that does not depend
-        # on ddp be initialized yet (and we don't want Megatron to initialize DDP itself either)
-        # and to return a hook for us to call after PTL has torch.distributed initialized.
-        # (or if no PTL in case of inference - then we'll initialize torch.distributed)
-        # We call and clear this hook on first call to forward()
-        self._lazy_init_fn = initialize_megatron(
-            extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
-        )
-
-        # read Megatron arguments back
-        args = get_args()
-        logging.info(f'Megatron-lm argparse args: {args}')
-
-        self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func, num_tokentypes=num_tokentypes, add_pooler=False
-        )
-
-        self.config = OmegaConf.create(config)
-        # key used for checkpoints
-        self._hidden_size = self.language_model.hidden_size
-
-    def _update_megatron_args(
-        self,
-        micro_batch_size=1,
-        tensor_model_parallel_size=1,
-        scaled_masked_softmax_fusion=False,
-        bias_gelu_fusion=False,
-        bias_dropout_fusion=False,
-    ):
-        def extra_args_provider(parser):
-            parser.set_defaults(micro_batch_size=micro_batch_size)
-            parser.set_defaults(tensor_model_parallel_size=tensor_model_parallel_size)
-            parser.set_defaults(scaled_masked_softmax_fusion=scaled_masked_softmax_fusion)
-            parser.set_defaults(bias_gelu_fusion=bias_gelu_fusion)
-            parser.set_defaults(bias_dropout_fusion=bias_dropout_fusion)
-
-            return parser
-
-        return extra_args_provider
-
-    def complete_lazy_init(self):
-        # finish megatron-lm initialization
-        if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None:
-            self._lazy_init_fn()
-            self._lazy_init_fn = None
+        raise ValueError(
+            f'megatron-lm bert has been deprecated in NeMo 1.5. Please use an earlier release of NeMo. Megatron bert support will be added back to NeMo in a future release.'
+        )
+        # self._model_parallel_size = model_parallel_size
+        # self._model_parallel_rank = model_parallel_rank
+        # self._restore_path = None
+        # self._app_state = None
+        # self._model_name = model_name
+
+        # if 'vocab_size' in config:
+        #     self._vocab_size = config.pop('vocab_size')
+        # else:
+        #     self._vocab_size = None
+
+        # self._hidden_size = config.get('hidden_size')
+
+        # if not os.path.exists(vocab_file):
+        #     raise ValueError(f'Vocab file not found at {vocab_file}')
+
+        # # convert config to dictionary
+        # if isinstance(config, DictConfig):
+        #     config = OmegaConf.to_container(config)
+        # config["vocab_file"] = vocab_file
+        # config['tokenizer_type'] = 'BertWordPieceLowerCase'
+        # config['lazy_mpu_init'] = True
+        # config['onnx_safe'] = True
+
+        # num_tokentypes = config.pop('num_tokentypes', 2)
+
+        # # configure globals for megatron
+        # set_pipeline_model_parallel_rank(0)  # pipeline model parallelism not implemented in NeMo
+        # set_pipeline_model_parallel_world_size(1)  # pipeline model parallelism not implemented in NeMo
+
+        # self.language_model, self._language_model_key = get_language_model(
+        #     encoder_attn_mask_type=AttnMaskType.padding, num_tokentypes=num_tokentypes, add_pooler=False
+        # )
+
+        # self.config = OmegaConf.create(config)
+        # # key used for checkpoints
+        # self._hidden_size = self.language_model.hidden_size
 
     @property
     def hidden_size(self):
@@ -172,17 +114,14 @@ class MegatronBertEncoder(BertModule):
 
     @typecheck()
     def forward(self, input_ids, attention_mask, token_type_ids=None):
-        app_state = AppState()
-        if app_state.model_parallel_size is None:
-            self.complete_lazy_init()
-
         extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
         sequence_output = self.language_model(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            attention_mask=extended_attention_mask,
+            enc_input_ids=input_ids,
+            enc_position_ids=position_ids,
+            enc_attn_mask=extended_attention_mask,
             tokentype_ids=token_type_ids,
         )
         return sequence_output
@@ -196,13 +135,13 @@ class MegatronBertEncoder(BertModule):
         state_dict = torch.load(filename, map_location='cpu')
         if 'checkpoint_version' in state_dict:
             if state_dict['checkpoint_version'] is not None:
-                set_checkpoint_version(state_dict['checkpoint_version'])
+                set_megatron_checkpoint_version(state_dict['checkpoint_version'])
                 logging.info(
                     f"Megatron-lm checkpoint version found. Setting checkpoint_version to {state_dict['checkpoint_version']}."
                 )
         else:
             logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
-            set_checkpoint_version(0)
+            set_megatron_checkpoint_version(0)
         # to load from Megatron pretrained checkpoint
         if 'model' in state_dict:
             self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
@@ -233,3 +172,39 @@ class MegatronBertEncoder(BertModule):
             logging.info(f'torch.distributed not initialized yet. Will not restore model parallel checkpoint')
         else:
             logging.error(f'restore_path: {restore_path} must be a file or directory.')
+
+
+def get_megatron_checkpoint_version():
+    app_state = AppState()
+    return app_state._megatron_checkpoint_version
+
+
+def set_megatron_checkpoint_version(version: int = None):
+    app_state = AppState()
+    app_state._megatron_checkpoint_version = version
+
+
+def bert_extended_attention_mask(attention_mask):
+    # We create a 3D attention mask from a 2D tensor mask.
+    # [b, 1, s]
+    attention_mask_b1s = attention_mask.unsqueeze(1)
+    # [b, s, 1]
+    attention_mask_bs1 = attention_mask.unsqueeze(2)
+    # [b, s, s]
+    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
+    # [b, 1, s, s]
+    extended_attention_mask = attention_mask_bss.unsqueeze(1)
+
+    # Convert attention mask to binary:
+    extended_attention_mask = extended_attention_mask < 0.5
+
+    return extended_attention_mask
+
+
+def bert_position_ids(token_ids):
+    # Create position ids
+    seq_length = token_ids.size(1)
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
+
+    return position_ids
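
The two new mask helpers are pure torch, so their behavior is easy to check directly. A short runnable demonstration (one sequence whose last token is padding):

```python
import torch

def bert_extended_attention_mask(attention_mask):
    attention_mask_b1s = attention_mask.unsqueeze(1)               # [b, 1, s]
    attention_mask_bs1 = attention_mask.unsqueeze(2)               # [b, s, 1]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1   # [b, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)      # [b, 1, s, s]
    return extended_attention_mask < 0.5  # True marks positions to mask out

mask = torch.tensor([[1, 1, 1, 0]])  # 1 = real token, 0 = padding
ext = bert_extended_attention_mask(mask)
print(ext.shape)      # torch.Size([1, 1, 4, 4])
print(ext[0, 0, 0])   # tensor([False, False, False,  True]) -- padding is masked
```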
@@ -0,0 +1,89 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import torch
from apex.transformer import tensor_parallel
from apex.transformer.parallel_state import (
    get_pipeline_model_parallel_rank,
    set_pipeline_model_parallel_rank,
    set_pipeline_model_parallel_world_size,
    set_tensor_model_parallel_rank,
    set_tensor_model_parallel_world_size,
)

from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.utils import AppState


def initialize_model_parallel_for_nemo(
    world_size, global_rank, local_rank, tensor_model_parallel_size=1, seed=1234,
):

    # updating NeMo globals
    app_state = AppState()
    app_state.global_rank = global_rank
    app_state.world_size = world_size
    app_state.model_parallel_size = tensor_model_parallel_size
    app_state.model_parallel_rank = compute_model_parallel_rank(local_rank, tensor_model_parallel_size)

    # update apex.mpu globals
    set_tensor_model_parallel_world_size(tensor_model_parallel_size)
    set_tensor_model_parallel_rank(app_state.model_parallel_rank)

    # pipeline model parallelism not implemented in NeMo yet
    set_pipeline_model_parallel_rank(0)
    set_pipeline_model_parallel_world_size(1)

    _set_random_seed(seed)

    app_state._is_megatron_initialized = True


def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
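
The seeding scheme in `_set_random_seed` offsets the base seed by 100 per pipeline stage so that stages draw different random streams. A dependency-free sketch (the `pipeline_rank` loop variable stands in for apex's `get_pipeline_model_parallel_rank()`):

```python
import random

base_seed = 1234
for pipeline_rank in range(4):
    seed = base_seed + 100 * pipeline_rank  # stage-specific seed
    random.seed(seed)
    print(pipeline_rank, seed, random.random())  # distinct stream per stage
```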
nemo/collections/nlp/modules/common/megatron/megatron_utils.py (modified)
@@ -46,6 +46,14 @@ MEGATRON_CACHE = os.path.join(torch_home, "megatron")
 CONFIGS = {"345m": {"hidden_size": 1024, "num_attention_heads": 16, "num_layers": 24, "max_position_embeddings": 512}}
 
 MEGATRON_CONFIG_MAP = {
+    "megatron-gpt-345m": {
+        "config": CONFIGS["345m"],
+        "checkpoint": "models/nvidia/megatron_lm_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt",
+        "vocab": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+        "merges_file": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+        "do_lower_case": False,
+        "tokenizer_name": "gpt2",
+    },
     "megatron-bert-345m-uncased": {
         "config": CONFIGS["345m"],
         "checkpoint": "https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.0/files/release/mp_rank_00/model_optim_rng.pt",
@@ -97,6 +105,7 @@ def get_megatron_lm_model(
     config_file: Optional[str] = None,
     checkpoint_file: Optional[str] = None,
     vocab_file: Optional[str] = None,
+    merges_file: Optional[str] = None,
 ) -> Tuple[MegatronBertEncoder, str]:
     """
     Returns MegatronBertEncoder and a default or user specified path to the checkpoint file
@@ -113,81 +122,87 @@ def get_megatron_lm_model(
         model: MegatronBertEncoder
         checkpoint_file: path to checkpoint file or directory
     """
-    config = None
-    # get default config and checkpoint
-    if config_file:
-        with open(config_file) as f:
-            config = json.load(f)
-        # replace dashes with underscores in config keys
-        fixed_config = {}
-        for key in config.keys():
-            fixed_key = key.replace("-", "_")
-            if fixed_key == 'max_seq_length':
-                fixed_key = 'max_position_embeddings'
-            fixed_config[fixed_key] = config[key]
-        # 'vocab_size' no longer used.
-        if 'vocab_size' in fixed_config:
-            fixed_config.pop('vocab_size')
-        config = fixed_config
-    elif config_dict:
-        config = config_dict
-    elif pretrained_model_name in get_megatron_lm_models_list():
-        config = get_megatron_config(pretrained_model_name)
-    else:
-        raise ValueError(f"{pretrained_model_name} is not supported")
-
-    if config is None:
-        raise ValueError(f"config_file or config_dict is required for {pretrained_model_name}")
-
-    if not checkpoint_file:
-        checkpoint_file = get_megatron_checkpoint(pretrained_model_name)
-
-    if not vocab_file:
-        vocab_file = get_megatron_vocab_file(pretrained_model_name)
-
-    app_state = AppState()
-    if app_state.model_parallel_size is not None and app_state.model_parallel_rank is not None:
-        # model parallel already known from .nemo restore
-        model_parallel_size = app_state.model_parallel_size
-        model_parallel_rank = app_state.model_parallel_rank
-    elif os.path.isdir(checkpoint_file):
-        # starting training from megatron-lm checkpoint
-        mp_ranks = glob.glob(os.path.join(checkpoint_file, 'mp_rank*'))
-        model_parallel_size = len(mp_ranks)
-        app_state.model_parallel_size = model_parallel_size
-        logging.info(
-            (
-                f'restore_path: {checkpoint_file} is a directory. '
-                f'Assuming megatron model parallelism with '
-                f'model_parallel_size: {model_parallel_size}'
-            )
-        )
-        # try to get local rank from global
-        local_rank = None
-        try:
-            local_rank = int(os.environ['LOCAL_RANK'])
-        except:
-            logging.info('Global variable LOCAL_RANK not yet specified')
-        if local_rank is not None:
-            app_state.local_rank = local_rank
-        else:
-            # if local is None then we are on the main process
-            local_rank = 0
-        model_parallel_rank = compute_model_parallel_rank(local_rank, model_parallel_size)
-        app_state.model_parallel_rank = model_parallel_rank
-    else:
-        model_parallel_size = None
-        model_parallel_rank = None
-
-    model = MegatronBertEncoder(
-        model_name=pretrained_model_name,
-        config=config,
-        vocab_file=vocab_file,
-        model_parallel_size=model_parallel_size,
-        model_parallel_rank=model_parallel_rank,
-    )
-
-    return model, checkpoint_file
+    raise ValueError(
+        f'megatron-lm bert has been deprecated in NeMo 1.5. Please use an earlier release of NeMo. Megatron bert support will be added back to NeMo in a future release.'
+    )
+    # config = None
+    # # get default config and checkpoint
+    # if config_file:
+    #     with open(config_file) as f:
+    #         config = json.load(f)
+    #     # replace dashes with underscores in config keys
+    #     fixed_config = {}
+    #     for key in config.keys():
+    #         fixed_key = key.replace("-", "_")
+    #         if fixed_key == 'max_seq_length':
+    #             fixed_key = 'max_position_embeddings'
+    #         fixed_config[fixed_key] = config[key]
+    #     # 'vocab_size' no longer used.
+    #     if 'vocab_size' in fixed_config:
+    #         fixed_config.pop('vocab_size')
+    #     config = fixed_config
+    # elif config_dict:
+    #     config = config_dict
+    # elif pretrained_model_name in get_megatron_lm_models_list():
+    #     config = get_megatron_config(pretrained_model_name)
+    # else:
+    #     raise ValueError(f"{pretrained_model_name} is not supported")
+
+    # if config is None:
+    #     raise ValueError(f"config_file or config_dict is required for {pretrained_model_name}")
+
+    # if not checkpoint_file:
+    #     checkpoint_file = get_megatron_checkpoint(pretrained_model_name)
+
+    # if not vocab_file:
+    #     vocab_file = get_megatron_vocab_file(pretrained_model_name)
+
+    # if not merges_file:
+    #     merges_file = get_megatron_merges_file(pretrained_model_name)
+
+    # app_state = AppState()
+    # if app_state.model_parallel_size is not None and app_state.model_parallel_rank is not None:
+    #     # model parallel already known from .nemo restore
+    #     model_parallel_size = app_state.model_parallel_size
+    #     model_parallel_rank = app_state.model_parallel_rank
+    # elif os.path.isdir(checkpoint_file):
+    #     # starting training from megatron-lm checkpoint
+    #     mp_ranks = glob.glob(os.path.join(checkpoint_file, 'mp_rank*'))
+    #     model_parallel_size = len(mp_ranks)
+    #     app_state.model_parallel_size = model_parallel_size
+    #     logging.info(
+    #         (
+    #             f'restore_path: {checkpoint_file} is a directory. '
+    #             f'Assuming megatron model parallelism with '
+    #             f'model_parallel_size: {model_parallel_size}'
+    #         )
+    #     )
+    #     # try to get local rank from global
+    #     local_rank = None
+    #     try:
+    #         local_rank = int(os.environ['LOCAL_RANK'])
+    #     except:
+    #         logging.info('Global variable LOCAL_RANK not yet specified')
+    #     if local_rank is not None:
+    #         app_state.local_rank = local_rank
+    #     else:
+    #         # if local is None then we are on the main process
+    #         local_rank = 0
+    #     model_parallel_rank = compute_model_parallel_rank(local_rank, model_parallel_size)
+    #     app_state.model_parallel_rank = model_parallel_rank
+    # else:
+    #     model_parallel_size = None
+    #     model_parallel_rank = None
+
+    # model = MegatronBertEncoder(
+    #     model_name=pretrained_model_name,
+    #     config=config,
+    #     vocab_file=vocab_file,
+    #     model_parallel_size=model_parallel_size,
+    #     model_parallel_rank=model_parallel_rank,
+    # )
+
+    # return model, checkpoint_file
@@ -239,6 +254,26 @@ def get_megatron_vocab_file(pretrained_model_name: str) -> str:
     return path
 
 
+def get_megatron_merges_file(pretrained_model_name: str) -> str:
+    """
+    Gets merges file from cache or downloads it
+
+    Args:
+        pretrained_model_name: pretrained model name
+
+    Returns:
+        path: path to the merges file
+    """
+    if 'gpt' not in pretrained_model_name.lower():
+        return None
+    _check_megatron_name(pretrained_model_name)
+    url = MEGATRON_CONFIG_MAP[pretrained_model_name]["merges_file"]
+
+    path = os.path.join(MEGATRON_CACHE, pretrained_model_name + "_merges")
+    path = _download(path, url)
+    return path
+
+
 def get_megatron_checkpoint(pretrained_model_name: str) -> str:
     """
     Gets checkpoint file from cache or downloads it
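
The config-key normalization that this hunk retires (kept only in the commented-out body) is simple to verify in isolation. A runnable sketch with illustrative key names and values: dashes become underscores, the legacy `max-seq-length` key is renamed, and `vocab_size` is dropped.

```python
raw_config = {'hidden-size': 1024, 'max-seq-length': 512, 'vocab-size': 30522}

fixed_config = {}
for key in raw_config.keys():
    fixed_key = key.replace("-", "_")
    if fixed_key == 'max_seq_length':
        fixed_key = 'max_position_embeddings'
    fixed_config[fixed_key] = raw_config[key]
fixed_config.pop('vocab_size', None)  # 'vocab_size' is no longer used

print(fixed_config)  # {'hidden_size': 1024, 'max_position_embeddings': 512}
```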
nemo/collections/nlp/modules/common/megatron/module.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron Module"""

import torch
from apex.transformer import parallel_state, tensor_parallel


def param_is_not_shared(param):
    return not hasattr(param, 'shared') or not param.shared


class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, share_word_embeddings=True):
        super(MegatronModule, self).__init__()
        self.share_word_embeddings = share_word_embeddings

    def word_embeddings_weight(self):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last stage, but share_word_embeddings is false')
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be called for first and last stage only')

    def initialize_word_embeddings(self, init_method, vocab_size, hidden_size, pipeline_model_parallel_size=1):
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but share_word_embeddings is false')

        # TODO: pipeline model parallelism is not implemented in NeMo yet
        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism there is nothing to do.
        if pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layer, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if parallel_state.is_pipeline_last_stage():
            assert not parallel_state.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                vocab_size, hidden_size, init_method=init_method
            )
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()
                )
        else:
            print(
                "WARNING! Distributed processes aren't initialized, so "
                "word embeddings in the last layer are not initialized. "
                "If you are just manipulating a model this is fine, but "
                "this needs to be handled manually. If you are training "
                "something is definitely wrong."
            )
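
The weight sharing this module manages is the standard tied-embedding setup: the input embedding and the output head use the same parameter, so logits are computed against the embedding matrix (compare `parallel_lm_logits` above). A dependency-free sketch of that tying in plain torch:

```python
import torch
import torch.nn.functional as F

vocab_size, hidden = 100, 16
embedding = torch.nn.Embedding(vocab_size, hidden)  # single shared parameter

tokens = torch.randint(0, vocab_size, (2, 8))
hidden_states = embedding(tokens)                   # [b, s, h]
logits = F.linear(hidden_states, embedding.weight)  # reuse the same weight as the head
print(logits.shape)  # torch.Size([2, 8, 100])
```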
nemo/collections/nlp/modules/common/megatron/transformer.py (new file, 850 lines)
@@ -0,0 +1,850 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math

import torch
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType, AttnType, LayerType
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax

from nemo.collections.nlp.modules.common.megatron.fused_bias_dropout_add import (
    BiasDropoutAddFusedInference,
    BiasDropoutAddFusedTrain,
    bias_dropout_add,
)
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import FusedBiasGeLU
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu


""" We use the following notation throughout this file:
    h: hidden size
    n: number of attention heads
    p: number of model parallel partitions
    np: n/p
    hp: h/p
    hn: h/n
    b: batch size
    s: sequence length
    l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
    hyperparameters: transformer hyperparameters
"""


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        hidden_size,
        ffn_hidden_size,
        use_cpu_initialization=False,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
        fused_fp16=False,
        fused_bf16=False,
    ):
        super(ParallelMLP, self).__init__()

        # Project to 4h.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
        )

        self.bias_gelu_fusion = bias_gelu_fusion
        self.activation_func = F.gelu
        if openai_gelu:
            self.activation_func = openai_gelu
        elif onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            ffn_hidden_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
        )

        self.bias_gelu_impl = FusedBiasGeLU(fused_fp16, fused_bf16)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = self.bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
|
||||
|
||||
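For orientation, here is a minimal single-GPU sketch of the same h -> ffn -> h dataflow, with the tensor-parallel linears replaced by plain torch.nn.Linear. Names and sizes are illustrative only and are not part of the diff above.

import torch
import torch.nn.functional as F

# Reference (non-parallel) version of the ParallelMLP dataflow:
# project h -> ffn, apply GELU, project back to h.
h, ffn, s, b = 8, 32, 5, 2
dense_h_to_4h = torch.nn.Linear(h, ffn)
dense_4h_to_h = torch.nn.Linear(ffn, h)

x = torch.randn(s, b, h)                 # [s, b, h]
intermediate = F.gelu(dense_h_to_4h(x))  # [s, b, ffn]
out = dense_4h_to_h(intermediate)        # [s, b, h] -- same size as the input
assert out.shape == x.shape
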
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_number,
        num_attention_heads,
        hidden_size,
        attention_type=AttnType.self_attn,
        attn_mask_type=AttnMaskType.padding,
        fused_fp16=False,
        fused_bf16=False,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        use_cpu_initialization=False,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
    ):
        super(ParallelAttention, self).__init__()

        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = False
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
        projection_size = kv_channels * num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = tensor_parallel.divide(projection_size, world_size)
        self.hidden_size_per_attention_head = tensor_parallel.divide(projection_size, num_attention_heads)
        self.num_attention_heads_per_partition = tensor_parallel.divide(num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method,
                use_cpu_initialization=use_cpu_initialization,
            )
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                hidden_size, projection_size, gather_output=False, init_method=init_method
            )

            self.key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size, 2 * projection_size, gather_output=False, init_method=init_method
            )

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            fused_fp16,
            fused_bf16,
            self.attn_mask_type,
            masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
        )

    def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head,
            )
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ====================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ====================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ..., attention_scores.size(3) - 1, : attention_scores.size(3)
                    ].unsqueeze(2)
                else:
                    attention_mask = attention_mask[..., : attention_scores.size(3), : attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with tensor_parallel.random.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias

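The score computation above folds batch and heads into one dimension and uses baddbmm with alpha = 1/norm_factor. A single-partition sketch of just that step, with illustrative sizes (not part of the diff):

import math
import torch

# [sq, b, np, hn] query and [sk, b, np, hn] key, as in the forward above.
b, np_, sq, sk, hn = 2, 4, 5, 5, 16
query = torch.randn(sq, b, np_, hn)
key = torch.randn(sk, b, np_, hn)

q = query.view(sq, b * np_, hn).transpose(0, 1)                 # [b*np, sq, hn]
k = key.view(sk, b * np_, hn).transpose(0, 1).transpose(1, 2)   # [b*np, hn, sk]
scores = torch.baddbmm(
    torch.empty(b * np_, sq, sk), q, k, beta=0.0, alpha=1.0 / math.sqrt(hn)
)  # beta=0.0 means the preallocated tensor's contents are ignored
print(scores.shape)  # [b*np, sq, sk], viewed later as [b, np, sq, sk]
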
def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bias_dropout_add

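The fused train/inference variants used by the layer below implement the same arithmetic as this small reference; this is a sketch of the expected semantics, not the fused kernel itself:

import torch
import torch.nn.functional as F

def bias_dropout_add_reference(x, bias, residual, prob, training):
    # dropout(x + bias), then added onto the residual branch
    out = F.dropout(x + bias, p=prob, training=training)
    return residual + out

x = torch.randn(5, 2, 8)
bias = torch.randn(8).expand_as(x)
residual = torch.randn(5, 2, 8)
y = bias_dropout_add_reference(x, bias, residual, prob=0.1, training=True)
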
class ParallelTransformerLayer_(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_number,
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        layer_type=LayerType.encoder,
        self_attn_mask_type=AttnMaskType.padding,
        fp32_residual_connection=False,
        apply_residual_connection_post_layernorm=False,
        fused_fp16=False,
        fused_bf16=False,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.1,
        bias_dropout_fusion=True,
        use_cpu_initialization=False,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
    ):
        super(ParallelTransformerLayer_, self).__init__()

        assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        self.layer_number = layer_number
        self.layer_type = layer_type

        # If true, apply the residual connection post layer norm (like the original BERT).
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

        self.bf16 = fused_bf16
        self.fp32_residual_connection = fp32_residual_connection  # if true move residual connections to fp32

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type,
            fused_fp16=fused_fp16,
            fused_bf16=fused_bf16,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            use_cpu_initialization=use_cpu_initialization,
            masked_softmax_fusion=masked_softmax_fusion,
            attention_dropout=attention_dropout,
        )
        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion  # if true, enable bias dropout fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                layer_number=layer_number,
                num_attention_heads=num_attention_heads,
                hidden_size=hidden_size,
                attention_type=AttnType.cross_attn,
                attn_mask_type=AttnMaskType.padding,
                fused_fp16=fused_fp16,
                fused_bf16=fused_bf16,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                use_cpu_initialization=use_cpu_initialization,
                masked_softmax_fusion=masked_softmax_fusion,
                attention_dropout=attention_dropout,
            )
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16
            )

        # MLP
        self.mlp = ParallelMLP(
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            use_cpu_initialization=use_cpu_initialization,
            bias_gelu_fusion=bias_gelu_fusion,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            fused_fp16=fused_fp16,
            fused_bf16=fused_bf16,
        )

        self.bias_dropout_add_fused_train = BiasDropoutAddFusedTrain(fused_fp16, fused_bf16)
        self.bias_dropout_add_fused_inference = BiasDropoutAddFusedInference(fused_fp16, fused_bf16)

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        layer_past=None,
        get_key_value=False,
    ):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = self.self_attention(
            layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value
        )

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = self.bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = self.bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
            )

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = self.inter_attention(
                layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
            )
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
                )

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output

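The dataflow above is the standard pre-LayerNorm residual structure. A compact, unparallelized PyTorch skeleton of the self-attention path only (fusion, dropout-on-bias, and the decoder branch omitted; illustrative names):

import torch

class PreLNBlockSketch(torch.nn.Module):
    def __init__(self, h, n_heads):
        super().__init__()
        self.input_layernorm = torch.nn.LayerNorm(h)
        self.attn = torch.nn.MultiheadAttention(h, n_heads)  # expects [s, b, h]
        self.post_attention_layernorm = torch.nn.LayerNorm(h)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(h, 4 * h), torch.nn.GELU(), torch.nn.Linear(4 * h, h)
        )

    def forward(self, x):  # x: [s, b, h]
        ln = self.input_layernorm(x)
        attn_out, _ = self.attn(ln, ln, ln)
        x = x + attn_out  # residual taken from the block input (pre-LN)
        return x + self.mlp(self.post_attention_layernorm(x))

y = PreLNBlockSketch(8, 2)(torch.randn(5, 2, 8))
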
class ParallelTransformerLayer(ParallelTransformerLayer_):
    def __init__(self, **kwargs):
        super(ParallelTransformerLayer, self).__init__(**kwargs)

        assert not (kwargs['fused_fp16'] and kwargs['fused_bf16'])
        if kwargs['fused_fp16']:
            self.dtype = torch.float16
        elif kwargs['fused_bf16']:
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        layer_past=None,
        get_key_value=False,
    ):

        if self.dtype == torch.float32:
            return super().forward(
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value
            )
        else:
            with torch.cuda.amp.autocast(dtype=self.dtype):
                return super().forward(
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value
                )

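The subclass above picks a dtype once from the fused_fp16/fused_bf16 flags and reruns the parent forward under autocast. The dispatch pattern in isolation (a sketch; the autocast branch needs a CUDA device to have any effect):

import torch

def run_with_optional_autocast(fn, dtype, *args):
    # fp32 runs as-is; fp16/bf16 run inside a torch.cuda.amp.autocast region
    if dtype == torch.float32:
        return fn(*args)
    with torch.cuda.amp.autocast(dtype=dtype):
        return fn(*args)
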
class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        num_layers,
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        layer_type=LayerType.encoder,
        self_attn_mask_type=AttnMaskType.padding,
        pre_process=True,
        post_process=True,
        fused_fp16=False,
        fused_bf16=False,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.1,
        use_cpu_initialization=False,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
    ):
        super(ParallelTransformer, self).__init__()

        assert not (fused_fp16 and fused_bf16), "both fused_fp16 and fused_bf16 flags cannot be True at the same time."

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        self.bf16 = fused_bf16
        self.fp32_residual_connection = fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = activations_checkpoint_method
        self.activations_checkpoint_num_layers = activations_checkpoint_num_layers

        # Number of layers.
        assert (
            num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
        ), 'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                layer_number=layer_number,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                fused_fp16=fused_fp16,
                fused_bf16=fused_bf16,
                fp32_residual_connection=fp32_residual_connection,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )

        # TODO: get virtual_pipeline_model_parallel_size from apex.mpu
        # if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None:
        #     assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
        #         'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
        #     )
        #     # Number of layers in each model chunk is the number of layers in the stage,
        #     # divided by the number of model chunks in a stage.
        #     self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
        #     # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
        #     # layers to stages like (each list is a model chunk):
        #     # Stage 0: [0] [2] [4] [6]
        #     # Stage 1: [1] [3] [5] [7]
        #     # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
        #     # layers to stages like (each list is a model chunk):
        #     # Stage 0: [0, 1] [4, 5]
        #     # Stage 1: [2, 3] [6, 7]
        #     offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
        #         args.num_layers // args.virtual_pipeline_model_parallel_size
        #     ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        # else:
        #     # Each stage gets a contiguous set of layers.
        #     offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
        offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, fp16=fused_fp16, bf16=fused_bf16)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_

            return custom_forward

        # Make sure memory is freed.
        tensor_parallel.reset_checkpointed_activations_memory_buffer()

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # This further reduces memory usage by storing fewer checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    hidden_states,
                    attention_mask,
                    encoder_output,
                    enc_dec_attn_mask,
                )
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes fuller use of device memory by removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
                    )
                else:
                    hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        encoder_output=None,
        enc_dec_attn_mask=None,
    ):
        # Checks.
        if layer_past is not None:
            assert get_key_value, 'for not None values in layer_past, expected get_key_value to be set'
        if get_key_value:
            assert (
                self.activations_checkpoint_method is None
            ), 'get_key_value does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
            )
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    layer_past=past,
                    get_key_value=get_key_value,
                )
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output

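The 'uniform' strategy above walks the layer list in chunks of activations_checkpoint_num_layers and checkpoints each chunk's input. A rough single-GPU analogue using torch.utils.checkpoint (illustrative only; the file itself uses apex's tensor_parallel.checkpoint):

import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])

def run_chunk(start, end):
    def custom_forward(x):
        for i in range(start, end):
            x = layers[i](x)
        return x
    return custom_forward

x = torch.randn(2, 8, requires_grad=True)
chunk = 2  # plays the role of activations_checkpoint_num_layers
l = 0
while l < len(layers):
    # only the chunk input is stored; the chunk is recomputed in backward
    x = checkpoint(run_chunk(l, l + chunk), x)
    l += chunk
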
166  nemo/collections/nlp/modules/common/megatron/utils.py  Normal file
@@ -0,0 +1,166 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for models."""

import math

import torch
from apex.transformer import parallel_state


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    return gelu_impl(x)


# This is actually the Python equivalent of torch.nn.functional.gelu(), also with type hints for the ONNX exporter
@torch.jit.script
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group())
    averaged_losses = averaged_losses / torch.distributed.get_world_size(
        group=parallel_state.get_data_parallel_group()
    )

    return averaged_losses


def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss):
    """Build masks and position ids for a left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
        att_mask_batch, 1, seq_length, seq_length
    )

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, loss_mask, position_ids

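A toy call of the helper above (2 sequences of length 6, eod_token=0): with the reset flags off, the mask is a shared lower-triangular [1, 1, s, s] boolean tensor and positions run 0..s-1.

import torch

data = torch.tensor([[5, 3, 0, 7, 2, 0], [4, 4, 4, 0, 9, 1]])
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data, eod_token=0, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=True
)
print(attention_mask.shape)  # torch.Size([1, 1, 6, 6]); True marks masked-out positions
print(loss_mask[0])          # zeros at the eod positions
print(position_ids[0])       # tensor([0, 1, 2, 3, 4, 5])
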
class AutocastModuleWrapper(torch.nn.Module):
    def __init__(self, fp16=False, bf16=False):
        super(AutocastModuleWrapper, self).__init__()

        # TODO: @titu1994 Maybe useful to make this adaptive - prefer bf16 over fp16 when both flags are set and both dtypes are supported?
        assert not (fp16 and bf16)
        if fp16 or bf16:
            self.precision = torch.float16 if fp16 else torch.bfloat16
            self.use_autocast = True
        else:
            self.use_autocast = False

        # class or function to autocast should be defined
        self.func = None

    def autocast_forward(self, *args):

        assert self.func is not None
        if self.use_autocast:
            if isinstance(self.func, torch.autograd.function.FunctionMeta):
                fwd = torch.cuda.amp.custom_fwd(cast_inputs=self.precision)(self.func.apply)
            else:
                fwd = torch.cuda.amp.custom_fwd(cast_inputs=self.precision)(self.func)
            return fwd(*args)
        else:
            if isinstance(self.func, torch.autograd.function.FunctionMeta):
                return self.func.apply(*args)
            else:
                return self.func(*args)

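The custom_fwd(cast_inputs=...) decoration used above casts tensor inputs to the requested dtype when called inside an autocast region. The documented pattern for an autograd.Function, as a hedged sketch (CUDA required for the cast to take effect):

import torch

class ScaledAdd(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, a, b):
        # inputs arrive as fp16 when invoked inside torch.cuda.amp.autocast()
        return a + 2.0 * b

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_out):
        return grad_out, 2.0 * grad_out
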
@@ -31,6 +31,13 @@ from nemo.utils import logging
 __all__ = ['get_tokenizer', 'get_tokenizer_list']


+megatron_tokenizer_model_map = {
+    'BertWordPieceLowerCase': 'megatron-bert-345m-uncased',
+    'BertWordPieceCase': 'megatron-bert-345m-cased',
+    'GPT2BPETokenizer': 'megatron-gpt-345m',
+}
+
+
 def get_tokenizer_list() -> List[str]:
     """
     Returns all supported tokenizer names

@@ -57,6 +64,7 @@ def get_tokenizer(
     tokenizer_name: str,
     tokenizer_model: Optional[str] = None,
     vocab_file: Optional[str] = None,
+    merges_file: Optional[str] = None,
     special_tokens: Optional[Dict[str, str]] = None,
     use_fast: Optional[bool] = False,
     bpe_dropout: Optional[float] = 0.0,

@@ -84,6 +92,9 @@ def get_tokenizer(
         vocab_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_vocab_file(
             tokenizer_name
         )
+        merges_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_merges_file(
+            tokenizer_name
+        )
         tokenizer_name = get_megatron_tokenizer(tokenizer_name)

     if tokenizer_name == 'sentencepiece':

@@ -101,7 +112,11 @@ def get_tokenizer(
         f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_name}, vocab_file: {vocab_file}, special_tokens_dict: {special_tokens_dict}, and use_fast: {use_fast}"
     )
     return AutoTokenizer(
-        pretrained_model_name=tokenizer_name, vocab_file=vocab_file, **special_tokens_dict, use_fast=use_fast
+        pretrained_model_name=tokenizer_name,
+        vocab_file=vocab_file,
+        merges_file=merges_file,
+        **special_tokens_dict,
+        use_fast=use_fast,
     )

@@ -110,6 +125,7 @@ def get_nmt_tokenizer(
     model_name: Optional[str] = None,
     tokenizer_model: Optional[str] = None,
     vocab_file: Optional[str] = None,
+    merges_file: Optional[str] = None,
     special_tokens: Optional[Dict[str, str]] = None,
     use_fast: Optional[bool] = False,
     bpe_dropout: Optional[float] = 0.0,

@@ -138,10 +154,14 @@ def get_nmt_tokenizer(
     elif library == 'huggingface':
         logging.info(f'Getting HuggingFace AutoTokenizer with pretrained_model_name: {model_name}')
         return AutoTokenizer(
-            pretrained_model_name=model_name, vocab_file=vocab_file, **special_tokens_dict, use_fast=use_fast
+            pretrained_model_name=model_name,
+            vocab_file=vocab_file,
+            merges_file=merges_file,
+            **special_tokens_dict,
+            use_fast=use_fast,
         )
     elif library == 'sentencepiece':
-        logging.info(f'Getting SentencePiece with model: {model_name}')
+        logging.info(f'Getting SentencePiece with model: {tokenizer_model}')
         return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer(
             model_path=tokenizer_model, special_tokens=special_tokens_dict
         )

@@ -149,10 +169,12 @@ def get_nmt_tokenizer(
         logging.info(f'Using byte-level tokenization')
         return ByteLevelTokenizer()
     elif library == 'megatron':
+        if model_name in megatron_tokenizer_model_map:
+            model_name = megatron_tokenizer_model_map[model_name]
         logging.info(
             f'Getting Megatron tokenizer for pretrained model name: {model_name} and custom vocab file: {vocab_file}'
         )
-        return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file)
+        return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file)
     else:
         raise NotImplementedError(
             'Currently we only support "yttm", "huggingface", "sentencepiece", "megatron", and "byte-level" tokenizer libraries.'

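With the new mapping, a Megatron tokenizer can be requested by architecture name. A hedged usage sketch (the import path assumes the file above is nemo/collections/nlp/modules/common/tokenizer_utils.py; vocab/merges files are fetched on first use):

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# 'GPT2BPETokenizer' is rewritten to 'megatron-gpt-345m' via
# megatron_tokenizer_model_map before get_tokenizer() is called.
tokenizer = get_nmt_tokenizer(library='megatron', model_name='GPT2BPETokenizer')
print(tokenizer.text_to_ids('hello world'))
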
@@ -13,23 +13,34 @@
 # limitations under the License.

 import os
+import shutil
+import tempfile
 from typing import Any, Dict, List, Optional, Union

+import pytorch_lightning as pl
 import torch
-from megatron import mpu
-from megatron.checkpointing import get_checkpoint_version, set_checkpoint_version
-from megatron.initialize import _set_random_seed
-from pytorch_lightning.core.lightning import LightningModule
+from apex.transformer import parallel_state
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.enums import GradClipAlgorithmType
 from torch.nn.modules.module import Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.optimizer import Optimizer

-from nemo.collections.nlp.modules.common.megatron.megatron_bert import MegatronBertEncoder
-from nemo.collections.nlp.modules.common.megatron.megatron_encoder import MegatronEncoderModule
+from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32
+from nemo.collections.nlp.modules.common.megatron.megatron_bert import (
+    get_megatron_checkpoint_version,
+    set_megatron_checkpoint_version,
+)
 from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.utils import AppState, logging

@@ -45,21 +56,22 @@ class NLPDDPPlugin(DDPPlugin):
         num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
         sync_batchnorm: bool = False,
+        checkpoint_io: Optional[CheckpointIO] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
-        super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm, **kwargs)
+        super().__init__(parallel_devices, num_nodes, cluster_environment, checkpoint_io, sync_batchnorm, **kwargs)

-    def init_ddp_connection(self, global_rank: int = None, world_size: int = None) -> None:
+    def setup_distributed(self, global_rank: int = None, world_size: int = None) -> None:
         # call PTL init ddp
-        super().init_ddp_connection()
+        super().setup_distributed()

         # init model parallel if needed
         app_state = AppState()

         if app_state.model_parallel_size is not None:

-            if self.lightning_module.has_megatron_encoder and not self.lightning_module.is_model_parallel_initialized:
-                self.init_model_parallel(app_state.global_rank, app_state.world_size)
+            self.init_model_parallel(app_state.global_rank, app_state.world_size)
+            # if self.lightning_module.has_megatron_encoder and not self.lightning_module.is_model_parallel_initialized:
+            #     self.init_model_parallel(app_state.global_rank, app_state.world_size)

     def start_training(self, trainer: 'Trainer') -> None:
         """ PTL Hook that is called after DDP is initialized. """

@@ -73,7 +85,7 @@ class NLPDDPPlugin(DDPPlugin):
             if not hasattr(p, 'model_parallel'):
                 p.model_parallel = False

-            if get_checkpoint_version() is not None:
+            if get_megatron_checkpoint_version() is not None:
                 # megatron checkpoint already restored
                 pass
             elif trainer.checkpoint_connector.resume_checkpoint_path is not None:

@@ -92,14 +104,14 @@ class NLPDDPPlugin(DDPPlugin):
                     'checkpoint_version', None
                 )
                 if checkpoint_version is not None:
-                    set_checkpoint_version(checkpoint_version)
+                    set_megatron_checkpoint_version(checkpoint_version)
                 else:
                     logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
-                    set_checkpoint_version(0)
+                    set_megatron_checkpoint_version(0)
             else:
                 self.lightning_module.restore_megatron_encoder_weights()
         else:
-            if get_checkpoint_version() is not None:
+            if get_megatron_checkpoint_version() is not None:
                 # megatron checkpoint already restored
                 pass
             else:

@@ -117,7 +129,7 @@ class NLPDDPPlugin(DDPPlugin):

         if self.has_megatron_encoder:
             # check megatron checkpoint version
-            checkpoint_version = get_checkpoint_version()
+            checkpoint_version = get_megatron_checkpoint_version()
             if checkpoint_version is None:
                 raise ValueError("Unable to find megatron checkpoint version.")

@@ -125,7 +137,7 @@ class NLPDDPPlugin(DDPPlugin):

     def configure_ddp(self):
         """ Override LightningModule ddp if using model parallel.
-            Sets find_unused_parameters to True.
+            Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
         """

         app_state = AppState()

@@ -142,7 +154,7 @@ class NLPDDPPlugin(DDPPlugin):
                 device_ids=device_ids,
                 output_device=device_ids[0],
                 process_group=app_state.data_parallel_group,
-                find_unused_parameters=True,
+                find_unused_parameters=False,
                 **self._ddp_kwargs,
             )

@@ -163,17 +175,35 @@ class NLPDDPPlugin(DDPPlugin):
         # after initializing DDP with PTL.
         if app_state.model_parallel_size is not None:
             if torch.distributed.is_initialized():
-                mpu.initialize_model_parallel(app_state.model_parallel_size)
-                app_state.model_parallel_group = mpu.get_model_parallel_group()
-                app_state.data_parallel_group = mpu.get_data_parallel_group()
-                app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
-                app_state.data_parallel_rank = mpu.get_data_parallel_rank()
-                app_state.data_parallel_size = mpu.get_data_parallel_world_size()
+                parallel_state.initialize_model_parallel(app_state.model_parallel_size)
+                app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
+                app_state.data_parallel_group = parallel_state.get_data_parallel_group()
+                app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
+                app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
+                app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
                 logging.info(f'mp_rank: {app_state.model_parallel_rank}')
                 logging.info(f'dp_rank: {app_state.data_parallel_rank}')
                 seed = os.environ.get("PL_GLOBAL_SEED", 1234)
                 # random seed must be set for megatron model parallel init
                 _set_random_seed(seed)

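The plugin relies on a specific initialization order: a default torch.distributed process group must already exist before apex model parallelism is configured. A sketch of that order under an assumed torchrun/NCCL launch (sizes illustrative):

import torch
from apex.transformer import parallel_state

# assumes the process was launched with torchrun and CUDA is available
torch.distributed.init_process_group(backend='nccl')
parallel_state.initialize_model_parallel(2)  # model_parallel_size=2, as in the code above
print(parallel_state.get_tensor_model_parallel_rank(),
      parallel_state.get_data_parallel_rank())
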
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: dict containing model and trainer state
+            filepath: write-target file's path
+        """
+        app_state = AppState()
+        # dump states as a checkpoint dictionary object
+        # TrainingTypePlugin.on_save() just seems to return the same thing.
+        # checkpoint = self.on_save(checkpoint)
+        if self.is_global_zero or app_state.data_parallel_rank == 0:
+            try:
+                # write the checkpoint dictionary on the file
+                atomic_save(checkpoint, filepath)
+            except AttributeError as err:
+                key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
+                checkpoint.pop(key, None)
+                rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
+                atomic_save(checkpoint, filepath)

     @property
     def distributed_sampler_kwargs(self):

@@ -195,18 +225,17 @@ class NLPCheckpointConnector(CheckpointConnector):
     """ Override PTL CheckpointConnector to support model parallel checkpoints from Megatron-LM.
     """

-    def __init__(self, trainer):
-        super().__init__(trainer)
+    def __init__(self, trainer, resume_from_checkpoint):
+        super().__init__(trainer, resume_from_checkpoint)

-    def save_checkpoint(self, filepath, weights_only: bool):
+    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         """Slightly modified version of PyTorch Lightning's save_checkpoint.
-        Accounts for model parallel training.
+        Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
-            filepath ([str]): [description]
-            weights_only (bool): [description]
-
-        Returns:
-            [type]: [description]
+            filepath: write-target file's path
+            weights_only: saving model weights only
         """
         app_state = AppState()
         if app_state.model_parallel_size is not None:

@@ -214,22 +243,137 @@ class NLPCheckpointConnector(CheckpointConnector):
             dirname = os.path.dirname(filepath)
             basename = os.path.basename(filepath)
             filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'

             # dump states as a checkpoint dictionary object
-            checkpoint = self.dump_checkpoint(weights_only)
+            _checkpoint = self.dump_checkpoint(weights_only)
             # each model parallel rank needs to save a copy of its model
             if app_state.data_parallel_rank == 0:
-                # write the checkpoint dictionary on the file
-                if self.trainer.accelerator_backend:
-                    checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
-                try:
-                    atomic_save(checkpoint, filepath)
-                except AttributeError as err:
-                    if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                        del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-                    rank_zero_warn(
-                        'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
-                    )
-                    atomic_save(checkpoint, filepath)
-            return None
+                self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
         else:
             super().save_checkpoint(filepath, weights_only)

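The connector's path rewrite above is purely string-based; each tensor-parallel rank writes to its own mp_rank_XX subdirectory. In isolation (illustrative path):

import os

filepath = '/results/checkpoints/megatron_gpt--val_loss=1.23.ckpt'
model_parallel_rank = 1  # illustrative
dirname, basename = os.path.dirname(filepath), os.path.basename(filepath)
print(f'{dirname}/mp_rank_{model_parallel_rank:02d}/{basename}')
# /results/checkpoints/mp_rank_01/megatron_gpt--val_loss=1.23.ckpt
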
+class NLPSaveRestoreConnector(SaveRestoreConnector):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def save_to(self, model, save_path: str):
+        app_state = AppState()
+        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+
+            dir_name = os.path.dirname(save_path)
+
+            # first we save the weights for each model parallel rank
+            if app_state.data_parallel_rank == 0:
+                mp_model_weights = os.path.join(
+                    dir_name, f'mp_rank_{app_state.model_parallel_rank:02d}_' + self.model_weights_ckpt
+                )
+                self._save_state_dict_to_disk(model.state_dict(), mp_model_weights)
+
+            torch.distributed.barrier()
+
+            # create nemo file from folder with all mp_ranks checkpoints
+            if app_state.model_parallel_rank == 0 and app_state.data_parallel_rank == 0:
+                with tempfile.TemporaryDirectory() as tmpdir:
+
+                    # move weights to the tmpdir
+                    for mp_rank in range(app_state.model_parallel_size):
+                        os.makedirs(os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}'))
+                        mp_model_weights = os.path.join(dir_name, f'mp_rank_{mp_rank:02d}_' + self.model_weights_ckpt)
+                        shutil.move(
+                            mp_model_weights, os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}', self.model_weights_ckpt)
+                        )
+
+                    # create config and artifacts in tmpdir
+                    config_yaml = os.path.join(tmpdir, self.model_config_yaml)
+                    model.to_config_file(path2yaml_file=config_yaml)
+                    if hasattr(model, 'artifacts') and model.artifacts is not None:
+                        self._handle_artifacts(model, nemo_file_folder=tmpdir)
+                        self._update_artifact_paths(model, path2yaml_file=config_yaml)
+
+                    # create tar file
+                    self._make_nemo_file_from_folder(save_path, tmpdir)
+
+        else:
+            return super().save_to(model, save_path)

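The .nemo archive assembled above is just a tar of a temporary folder. A small runnable sketch of the layout save_to() builds before tarring; the file names assume the SaveRestoreConnector defaults (model_config.yaml, model_weights.ckpt), with model_parallel_size=2:

import os, tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    open(os.path.join(tmpdir, 'model_config.yaml'), 'w').close()
    for mp_rank in range(2):
        os.makedirs(os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}'))
        open(os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}', 'model_weights.ckpt'), 'w').close()
    for root, _, files in os.walk(tmpdir):
        for f in files:
            print(os.path.relpath(os.path.join(root, f), tmpdir))
# model_config.yaml
# mp_rank_00/model_weights.ckpt
# mp_rank_01/model_weights.ckpt
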
+class NLPNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
+    def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None:
+        super().__init__(precision=16)
+
+        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval)
+
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float],
+        gradient_clip_algorithm: GradClipAlgorithmType,
+        model: Optional[Module],
+    ) -> None:
+        """Override PTL gradient clipping.
+        Do nothing because we've already clipped gradients in the `on_before_optimizer_step` hook.
+        """
+        pass
+
+
+class NLPNativeBfloat16PrecisionPlugin(NativeMixedPrecisionPlugin):
+    def __init__(self) -> None:
+        super().__init__(precision='bf16')
+
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float],
+        gradient_clip_algorithm: GradClipAlgorithmType,
+        model: Optional[Module],
+    ) -> None:
+        """Override PTL gradient clipping.
+        Model parallel models require gradient clipping from megatron-lm.
+        """
+
+        if clip_val is None:
+            return
+
+        clip_val = float(clip_val)
+        if clip_val <= 0:
+            return
+
+        app_state = AppState()
+        if app_state.model_parallel_size is not None:
+            parameters = model.parameters()
+            clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
+        else:
+            return super().clip_gradients(
+                optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model
+            )
+
+
+class NLPPrecisionPlugin(PrecisionPlugin):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float],
+        gradient_clip_algorithm: GradClipAlgorithmType,
+        model: Optional[Module],
+    ) -> None:
+        """Override PTL gradient clipping.
+        Model parallel models require gradient clipping from megatron-lm.
+        """
+
+        if clip_val is None:
+            return
+
+        clip_val = float(clip_val)
+        if clip_val <= 0:
+            return
+
+        app_state = AppState()
+        if app_state.model_parallel_size is not None:
+            parameters = model.parameters()
+            clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
+        else:
+            return super().clip_gradients(
+                optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=model
+            )

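For the non-model-parallel branch the plugins fall back to PTL, which ultimately performs norm-based clipping like the reference below; broadly speaking, clip_grad_norm_fp32 does the same but additionally reduces the norm across model-parallel ranks (a sketch, not the Megatron implementation):

import torch

model = torch.nn.Linear(8, 8)
model(torch.randn(2, 8)).sum().backward()
# scales all gradients so that their global L2 norm is at most max_norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))
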
@@ -150,7 +150,7 @@ class FastSpeech2Model(SpectrogramGenerator):
         self.log(name="train_loss", value=total_loss)
         return {"loss": total_loss, "outputs": [spec, mel]}

-    def training_epoch_end(self, training_step_outputs):
+    def training_epoch_end(self, outputs):
         if self.log_train_images and self.logger is not None and self.logger.experiment is not None:
             tb_logger = self.logger.experiment
             if isinstance(self.logger, LoggerCollection):

@@ -158,7 +158,7 @@ class FastSpeech2Model(SpectrogramGenerator):
                 if isinstance(logger, TensorBoardLogger):
                     tb_logger = logger.experiment
                     break
-            spec_target, spec_predict = training_step_outputs[0]["outputs"]
+            spec_target, spec_predict = outputs[0]["outputs"]
             tb_logger.add_image(
                 "train_mel_target",
                 plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),

@@ -171,6 +171,8 @@ class FastSpeech2Model(SpectrogramGenerator):
             )
             self.log_train_images = False

+        return super().training_epoch_end(outputs)
+
     def validation_step(self, batch, batch_idx):
         f, fl, t, tl, _, _, _ = batch
         spec, spec_len = self.audio_to_melspec_preprocessor(f, fl)

@@ -126,7 +126,7 @@ class HifiGanModel(Vocoder, Exportable):
     def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
         return self(spec=spec).squeeze(1)

-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # if in finetune mode the mels are pre-computed using a
         # spectrogram generator
         if self.input_as_mel:

@@ -15,6 +15,7 @@

 """Interfaces common to all Neural Modules and Models."""
 import hashlib
+import inspect
 import traceback
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
@@ -429,7 +430,7 @@ class Typing(ABC):

 class Serialization(ABC):
     @classmethod
-    def from_config_dict(cls, config: 'DictConfig'):
+    def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = None):
         """Instantiates object using DictConfig-based configuration"""
         # Resolve the config dict
         if _HAS_HYDRA:
@@ -470,7 +471,12 @@ class Serialization(ABC):
             imported_cls = cls

         try:
-            instance = imported_cls(cfg=config)
+            accepts_trainer = Serialization._inspect_signature_for_trainer(imported_cls)
+            if accepts_trainer:
+                instance = imported_cls(cfg=config, trainer=trainer)
+            else:
+                instance = imported_cls(cfg=config)
+
         except Exception as e:
             imported_cls_tb = traceback.format_exc()
             instance_init_error = str(e)
@@ -484,8 +490,14 @@ class Serialization(ABC):
                 f"Falling back to `cls`.\n"
                 f"{imported_cls_tb}"
             )
-            instance = cls(cfg=config, trainer=trainer)
+            try:
+                instance = cls(cfg=config)
+                accepts_trainer = Serialization._inspect_signature_for_trainer(cls)
+                if accepts_trainer:
+                    instance = cls(cfg=config, trainer=trainer)
+                else:
+                    instance = cls(cfg=config)

             except Exception as e:
                 if imported_cls_tb is not None:
                     logging.error(f"Instance failed restore_from due to: {instance_init_error}")
@@ -515,6 +527,17 @@ class Serialization(ABC):
             'to_config_dict() can currently only return object._cfg but current object does not have it.'
         )

+    @classmethod
+    def _inspect_signature_for_trainer(cls, check_cls):
+        if hasattr(check_cls, '__init__'):
+            signature = inspect.signature(check_cls.__init__)
+            if 'trainer' in signature.parameters:
+                return True
+            else:
+                return False
+        else:
+            return False
+

 class FileIO(ABC):
     def save_to(self, save_path: str):
@@ -529,6 +552,8 @@ class FileIO(ABC):
         map_location: Optional['torch.device'] = None,
         strict: bool = True,
         return_config: bool = False,
+        trainer: Optional['Trainer'] = None,
+        save_restore_connector: SaveRestoreConnector = None,
     ):
         """Restores module/model with weights"""
         raise NotImplementedError()
@@ -644,6 +669,7 @@ class Model(Typing, Serialization, FileIO):
         map_location: Optional['torch.device'] = None,
         strict: bool = True,
         return_config: bool = False,
+        trainer: Optional['Trainer'] = None,
         save_restore_connector: SaveRestoreConnector = None,
     ):
         """
@@ -708,6 +734,7 @@ class Model(Typing, Serialization, FileIO):
             map_location=map_location,
             strict=strict,
             return_config=return_config,
+            trainer=trainer,
             save_restore_connector=save_restore_connector,
         )
         return instance

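The signature inspection added above is what lets `from_config_dict` stay backward compatible: classes that never declared a `trainer` parameter are constructed exactly as before. A self-contained sketch of the behavior (both classes here are hypothetical, not NeMo classes):

```python
import inspect


class LegacyModule:
    def __init__(self, cfg):  # no `trainer` parameter
        self.cfg = cfg


class TrainerAwareModule:
    def __init__(self, cfg, trainer=None):
        self.cfg, self.trainer = cfg, trainer


def accepts_trainer(klass) -> bool:
    # mirrors Serialization._inspect_signature_for_trainer
    return 'trainer' in inspect.signature(klass.__init__).parameters


print(accepts_trainer(LegacyModule))        # False -> instantiated as cls(cfg=config)
print(accepts_trainer(TrainerAwareModule))  # True  -> instantiated as cls(cfg=config, trainer=trainer)
```
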
@@ -201,7 +201,6 @@ class ModelPT(LightningModule, Model):

         return self._save_restore_connector.register_artifact(self, config_path, src, verify_src_exists)

-    @rank_zero_only
     def save_to(self, save_path: str):
         """
         Saves model instance (weights and configuration) into a .nemo file
@@ -215,12 +214,24 @@ class ModelPT(LightningModule, Model):
             save_path: Path to .nemo file where model instance should be saved
         """

         app_state = AppState()
         # Add NeMo rank check as well
-        if not is_global_rank_zero():
-            return
-        else:
+        if app_state.model_parallel_size is not None:
+            if app_state.model_parallel_size > 1:
+                if isinstance(self._save_restore_connector, SaveRestoreConnector):
+                    raise ValueError(
+                        'Default NeMo SaveRestoreConnector will not work in model parallel mode. You should use a connector which supports model parallel mode, such as NLPSaveRestoreConnector in NLP. You can also use a custom one.'
+                    )
+
             save_path = os.path.abspath(os.path.expanduser(save_path))
+            # connector checks for ranks properly, no need to check here
             self._save_restore_connector.save_to(self, save_path)
+        else:
+            if not is_global_rank_zero():
+                return
+            else:
+                save_path = os.path.abspath(os.path.expanduser(save_path))
+                self._save_restore_connector.save_to(self, save_path)

     @classmethod
     def restore_from(
@@ -231,6 +242,7 @@ class ModelPT(LightningModule, Model):
         strict: bool = True,
         return_config: bool = False,
         save_restore_connector: SaveRestoreConnector = None,
+        trainer: Optional[Trainer] = None,
     ):
         """
         Restores model instance (weights and configuration) from .nemo file.
@@ -244,6 +256,8 @@ class ModelPT(LightningModule, Model):
             strict: Passed to load_state_dict. By default True.
             return_config: If set to true, will return just the underlying config of the restored
                 model as an OmegaConf DictConfig object without instantiating the model.
+            trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the
+                instantiated model's constructor.
             save_restore_connector (SaveRestoreConnector): Can be overridden to add custom save and restore logic.

         Example:
@@ -268,7 +282,7 @@ class ModelPT(LightningModule, Model):

         cls.update_save_restore_connector(save_restore_connector)
         instance = cls._save_restore_connector.restore_from(
-            cls, restore_path, override_config_path, map_location, strict, return_config
+            cls, restore_path, override_config_path, map_location, strict, return_config, trainer
         )
         if isinstance(instance, ModelPT):
             instance._save_restore_connector = save_restore_connector

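A hedged usage sketch of the new `trainer` argument to `restore_from` (the Trainer flags and .nemo filename are illustrative, and a model parallel restore would additionally need a connector such as NLPSaveRestoreConnector):

```python
from pytorch_lightning import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

trainer = Trainer(gpus=2, accelerator='ddp', precision=16)
model = MegatronGPTModel.restore_from(
    restore_path='megatron_gpt.nemo',
    trainer=trainer,  # forwarded to the constructor when its signature accepts `trainer`
)
```
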
@@ -48,6 +48,7 @@ class TrainerConfig:
     tpu_cores: Optional[Any] = None
     log_gpu_memory: Optional[str] = None
     progress_bar_refresh_rate: int = 1
+    enable_progress_bar: bool = True
     overfit_batches: Any = 0.0
     track_grad_norm: Any = -1
     check_val_every_n_epoch: int = 1
@@ -65,11 +66,10 @@ class TrainerConfig:
     log_every_n_steps: int = 50
     accelerator: Optional[str] = None
     sync_batchnorm: bool = False
-    precision: int = 32
+    precision: Any = 32
     weights_summary: Optional[str] = "full"  # ModelSummary.MODE_DEFAULT
     weights_save_path: Optional[str] = None
     num_sanity_val_steps: int = 2
-    truncated_bptt_steps: Optional[int] = None
     resume_from_checkpoint: Optional[str] = None
     profiler: Optional[Any] = None
     benchmark: bool = False
@@ -77,11 +77,12 @@ class TrainerConfig:
     reload_dataloaders_every_epoch: bool = False
     auto_lr_find: Any = False
     replace_sampler_ddp: bool = True
+    detect_anomaly: bool = False
     terminate_on_nan: bool = False
     auto_scale_batch_size: Any = False
     prepare_data_per_node: bool = True
     amp_backend: str = 'native'
-    amp_level: str = 'O2'  # backward compatible, todo: remove in v1.0.0
+    amp_level: Optional[str] = None
     plugins: Optional[Any] = None  # Optional[Union[str, list]]
     move_metrics_to_cpu: bool = False
     multiple_trainloader_mode: str = 'max_size_cycle'

@@ -51,6 +51,18 @@ class WarmupHoldSchedulerParams(WarmupSchedulerParams):
     min_lr: float = 0.0


+@dataclass
+class WarmupAnnealingHoldSchedulerParams(WarmupSchedulerParams):
+    """
+    Base configuration for all schedulers.
+    It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+    """
+
+    constant_steps: Optional[float] = None
+    constant_ratio: Optional[float] = None
+    min_lr: float = 0.0
+
+
 @dataclass
 class SquareAnnealingParams(WarmupSchedulerParams):
     """
@@ -72,7 +84,7 @@ class SquareRootAnnealingParams(WarmupSchedulerParams):


 @dataclass
-class CosineAnnealingParams(WarmupSchedulerParams):
+class CosineAnnealingParams(WarmupAnnealingHoldSchedulerParams):
     """
     Cosine Annealing parameter config
     It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
@@ -98,7 +110,7 @@ class WarmupAnnealingParams(WarmupSchedulerParams):
     It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
     """

-    warmup_ratio: 0.0
+    warmup_ratio: float = 0.0


 @dataclass
@@ -240,6 +252,7 @@ AVAILABLE_SCHEDULER_PARAMS = {
     'SchedulerParams': SchedulerParams,
     'WarmupPolicyParams': WarmupSchedulerParams,
     'WarmupHoldPolicyParams': WarmupHoldSchedulerParams,
    'WarmupAnnealingHoldSchedulerParams': WarmupAnnealingHoldSchedulerParams,
     'SquareAnnealingParams': SquareAnnealingParams,
     'SquareRootAnnealingParams': SquareRootAnnealingParams,
     'InverseSquareRootAnnealingParams': InverseSquareRootAnnealingParams,

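Since `CosineAnnealingParams` now inherits the hold-phase fields, a scheduler config can request warmup, cosine decay, and a constant hold in one dataclass. A hedged sketch (field values and the import path are illustrative):

```python
from nemo.core.config.schedulers import CosineAnnealingParams

params = CosineAnnealingParams(
    warmup_steps=100,
    constant_steps=1000,  # hold at min_lr for the final 1000 steps
    min_lr=1e-5,
)
print(params)
```
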
@@ -23,9 +23,11 @@ from typing import Optional, Union

 import torch
 from omegaconf import DictConfig, OmegaConf
 from omegaconf.omegaconf import open_dict
+from pytorch_lightning.trainer.trainer import Trainer

 from nemo.utils import logging, model_utils
 from nemo.utils.app_state import AppState
+from nemo.utils.get_rank import is_global_rank_zero


 class SaveRestoreConnector:
@@ -47,16 +49,19 @@ class SaveRestoreConnector:
         save_path: Path to .nemo file where model instance should be saved
         """

-        with tempfile.TemporaryDirectory() as tmpdir:
-            config_yaml = os.path.join(tmpdir, self.model_config_yaml)
-            model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
-            model.to_config_file(path2yaml_file=config_yaml)
-            if hasattr(model, 'artifacts') and model.artifacts is not None:
-                self._handle_artifacts(model, nemo_file_folder=tmpdir)
-                # We should not update self._cfg here - the model can still be in use
-                self._update_artifact_paths(model, path2yaml_file=config_yaml)
-            self._save_state_dict_to_disk(model.state_dict(), model_weights)
-            self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
+        if is_global_rank_zero():
+            with tempfile.TemporaryDirectory() as tmpdir:
+                config_yaml = os.path.join(tmpdir, self.model_config_yaml)
+                model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
+                model.to_config_file(path2yaml_file=config_yaml)
+                if hasattr(model, 'artifacts') and model.artifacts is not None:
+                    self._handle_artifacts(model, nemo_file_folder=tmpdir)
+                    # We should not update self._cfg here - the model can still be in use
+                    self._update_artifact_paths(model, path2yaml_file=config_yaml)
+                self._save_state_dict_to_disk(model.state_dict(), model_weights)
+                self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
+        else:
+            return

     def restore_from(
         self,
@@ -66,6 +71,7 @@ class SaveRestoreConnector:
         map_location: Optional[torch.device] = None,
         strict: bool = True,
         return_config: bool = False,
+        trainer: Trainer = None,
     ):
         """
         Restores model instance (weights and configuration) from a .nemo file
@@ -125,7 +131,7 @@ class SaveRestoreConnector:
                 return instance
             else:
                 app_state = AppState()
-                if app_state.model_parallel_rank is not None:
+                if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1:
                     model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
                 else:
                     model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
@@ -133,7 +139,7 @@ class SaveRestoreConnector:
             os.chdir(cwd)
         # get the class
         calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir)
-        instance = calling_cls.from_config_dict(config=conf)
+        instance = calling_cls.from_config_dict(config=conf, trainer=trainer)
         instance = instance.to(map_location)
         # add load_state_dict override
         instance.load_state_dict(
@@ -369,6 +375,8 @@ class SaveRestoreConnector:

     @staticmethod
     def _make_nemo_file_from_folder(filename, source_dir):
+        dirname = os.path.dirname(filename)
+        os.makedirs(dirname, exist_ok=True)
         with tarfile.open(filename, "w:gz") as tar:
             tar.add(source_dir, arcname=".")

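For orientation, a minimal sketch of what `_make_nemo_file_from_folder` produces: a .nemo file is a gzipped tar of the staged config YAML, the weights checkpoint, and any registered artifacts. The member names and helper below are illustrative, not the NeMo API:

```python
import os
import shutil
import tarfile
import tempfile


def pack_nemo(save_path: str, config_yaml: str, weights_ckpt: str) -> None:
    # stage the pieces in one folder, then tar that folder as "."
    with tempfile.TemporaryDirectory() as tmpdir:
        shutil.copy(config_yaml, os.path.join(tmpdir, 'model_config.yaml'))
        shutil.copy(weights_ckpt, os.path.join(tmpdir, 'model_weights.ckpt'))
        os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
        with tarfile.open(save_path, 'w:gz') as tar:
            tar.add(tmpdir, arcname='.')
```
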
@@ -160,6 +160,84 @@ class WarmupHoldPolicy(WarmupPolicy):
         return self._get_lr(step)


+class WarmupAnnealHoldPolicy(_LRScheduler):
+    """Adds warmup kwargs and warmup logic to lr policy.
+    All arguments should be passed as kwargs for clarity.
+    Args:
+        warmup_steps: Number of training steps in warmup stage
+        warmup_ratio: Ratio of warmup steps to total steps
+        max_steps: Total number of steps while training or `None` for
+            infinite training
+        min_lr: Minimum learning rate to hold after the decay.
+        constant_steps: Number of steps to keep lr constant at.
+        constant_ratio: Ratio of steps to keep lr constant.
+    """
+
+    def __init__(
+        self,
+        optimizer,
+        *,
+        warmup_steps=None,
+        warmup_ratio=None,
+        constant_steps=None,
+        constant_ratio=None,
+        max_steps=None,
+        min_lr=0.0,
+        last_epoch=-1,
+    ):
+        assert not (
+            warmup_steps is not None and warmup_ratio is not None
+        ), "Either use a particular number of steps or a ratio"
+        assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total number of steps"
+
+        # It is necessary to assign all attributes *before* __init__,
+        # as class is wrapped by an inner class.
+        self.max_steps = max_steps
+
+        if warmup_steps is not None:
+            self.warmup_steps = warmup_steps
+        elif warmup_ratio is not None:
+            self.warmup_steps = int(warmup_ratio * max_steps)
+        else:
+            self.warmup_steps = 0
+
+        if constant_steps is not None:
+            self.constant_steps = constant_steps
+            self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
+            assert self.decay_steps > 0
+        elif constant_ratio is not None:
+            self.constant_steps = int(constant_ratio * max_steps)
+            self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
+            assert self.decay_steps > 0
+        else:
+            self.constant_steps = 0
+            self.decay_steps = max_steps - self.warmup_steps
+
+        self.min_lr = min_lr
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
+            )
+
+        step = self.last_epoch
+
+        if step <= self.warmup_steps:
+            lr_val = (step + 1) / (self.warmup_steps + 1)
+            return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+        if step > self.max_steps:
+            return [self.min_lr for _ in self.base_lrs]
+
+        return self._get_lr(step)
+
+    def _get_lr(self, step):
+        """Simple const lr policy"""
+        return self.base_lrs
+
+
 def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
     mult = ((max_steps - step) / max_steps) ** 0.5
     out_lr = initial_lr * mult

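A worked example of the step regions `WarmupAnnealHoldPolicy.__init__` sets up, using illustrative numbers:

```python
# With max_steps=1000, warmup_steps=100, constant_steps=200:
#   decay_steps = 1000 - (200 + 100) = 700
#   step <= 100        -> linear warmup toward the base lr
#   100 < step <= 800  -> annealing over the 700 decay steps
#   step > 800         -> held at min_lr (the constant phase)
max_steps, warmup_steps, constant_steps = 1000, 100, 200
decay_steps = max_steps - (constant_steps + warmup_steps)
assert decay_steps == 700
```
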
@@ -180,6 +258,30 @@ def _cosine_annealing(initial_lr, step, max_steps, min_lr):
     return out_lr


+def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr):
+
+    assert max_lr > min_lr
+    # Use linear warmup for the initial part.
+    if warmup_steps > 0 and step <= warmup_steps:
+        return max_lr * float(step) / float(warmup_steps)
+
+    # For any step past warmup_steps + decay_steps, use min_lr.
+    if step > warmup_steps + decay_steps:
+        return min_lr
+
+    # If we are done with the warmup period, use the decay style.
+    num_steps_ = step - warmup_steps
+    decay_steps_ = decay_steps
+    decay_ratio = float(num_steps_) / float(decay_steps_)
+    assert decay_ratio >= 0.0
+    assert decay_ratio <= 1.0
+    delta_lr = max_lr - min_lr
+
+    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+
+    return min_lr + coeff * delta_lr
+
+
 def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
     if cycle:
         multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)

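A quick arithmetic check of the formula above (values illustrative): at the midpoint of the decay, `cos(pi/2) = 0`, so the learning rate sits halfway between `max_lr` and `min_lr`.

```python
import math

max_lr, min_lr, warmup_steps, decay_steps = 1e-3, 1e-5, 100, 700


def lr_at(step):
    if warmup_steps > 0 and step <= warmup_steps:
        return max_lr * step / warmup_steps
    if step > warmup_steps + decay_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / decay_steps
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)


print(lr_at(50))   # 5.0e-4: halfway through warmup
print(lr_at(450))  # 5.05e-4: midpoint of the cosine decay
print(lr_at(900))  # 1.0e-5: past decay, held at min_lr
```
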
@@ -221,7 +323,7 @@ class SquareRootAnnealing(WarmupPolicy):
         return new_lrs


-class CosineAnnealing(WarmupPolicy):
+class CosineAnnealing(WarmupAnnealHoldPolicy):
     def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
         super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

@@ -232,15 +334,27 @@ class CosineAnnealing(WarmupAnnealHoldPolicy):
                 f"{self} received an initial learning rate that was lower than the minimum learning rate."
             )

-        new_lrs = [
-            _cosine_annealing(
-                initial_lr=initial_lr,
-                step=step - self.warmup_steps,
-                max_steps=self.max_steps - self.warmup_steps,
-                min_lr=self.min_lr,
-            )
-            for initial_lr in self.base_lrs
-        ]
+        if self.constant_steps is None:
+            new_lrs = [
+                _cosine_annealing(
+                    initial_lr=initial_lr,
+                    step=step - self.warmup_steps,
+                    max_steps=self.max_steps - self.warmup_steps,
+                    min_lr=self.min_lr,
+                )
+                for initial_lr in self.base_lrs
+            ]
+        else:
+            new_lrs = [
+                _linear_warmup_with_cosine_annealing(
+                    max_lr=self.base_lrs[0],
+                    warmup_steps=self.warmup_steps,
+                    step=step,
+                    decay_steps=self.decay_steps,
+                    min_lr=self.min_lr,
+                )
+                for _ in self.base_lrs
+            ]
        return new_lrs

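A hedged usage sketch of the reworked `CosineAnnealing` scheduler with a hold phase (optimizer and hyperparameters are illustrative):

```python
import torch
from nemo.core.optim.lr_scheduler import CosineAnnealing

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = CosineAnnealing(
    opt, max_steps=1000, warmup_steps=100, constant_steps=200, min_lr=1e-5
)
for _ in range(5):
    opt.step()
    sched.step()
print(sched.get_last_lr())
```
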
@@ -587,7 +701,17 @@ def prepare_lr_scheduler(

     # Compute effective num max_steps
     num_samples = len(train_dataloader.dataset)
-    batch_size = train_dataloader.batch_size
+    # TODO: not sure if this will be the correct LR schedule for Megatron
+    # we may need to override ModelPT setup_optimization
+    if train_dataloader.batch_size is not None:
+        batch_size = train_dataloader.batch_size
+    elif train_dataloader.batch_sampler is not None:
+        if train_dataloader.batch_sampler.micro_batch_size is not None:
+            batch_size = train_dataloader.batch_sampler.micro_batch_size
+        else:
+            raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}')
+    else:
+        raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}')
     drop_last = train_dataloader.drop_last

     max_steps = compute_max_steps(

@@ -17,6 +17,7 @@ from functools import partial
 from typing import Any, Dict, List, Optional, Union

 import hydra
+import torch
 import torch.optim as optim
 from omegaconf import DictConfig, OmegaConf
 from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
@@ -41,10 +42,12 @@ AVAILABLE_OPTIMIZERS = {

 try:
     from apex.optimizers import FusedLAMB
+    from apex.optimizers import FusedAdam

     AVAILABLE_OPTIMIZERS['lamb'] = FusedLAMB
+    AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam
 except ModuleNotFoundError:
-    logging.warning("Apex was not found. Using the lamb optimizer will error out.")
+    logging.warning("Apex was not found. Using the lamb or fused_adam optimizer will error out.")

 __all__ = ['get_optimizer', 'register_optimizer', 'parse_optimizer_args']

@@ -165,6 +168,9 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
         raise ValueError(
             f"Cannot resolve optimizer '{name}'. Available optimizers are : " f"{AVAILABLE_OPTIMIZERS.keys()}"
         )
+    if name == 'fused_adam':
+        if not torch.cuda.is_available():
+            raise ValueError(f'CUDA must be available to use fused_adam.')

     optimizer = AVAILABLE_OPTIMIZERS[name]
     optimizer = partial(optimizer, **kwargs)

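A hedged usage sketch of the newly registered `fused_adam` entry. `get_optimizer` returns a `partial`, so the parameters are supplied at call time; this requires apex and a CUDA device, and the hyperparameters are illustrative:

```python
import torch
from nemo.core.optim.optimizers import get_optimizer

model = torch.nn.Linear(8, 8).cuda()
opt_factory = get_optimizer('fused_adam', lr=1e-4, weight_decay=0.01)
optimizer = opt_factory(model.parameters())  # partial: params supplied here
```
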
@@ -14,9 +14,9 @@


 MAJOR = 1
-MINOR = 4
+MINOR = 5
 PATCH = 0
-PRE_RELEASE = ''
+PRE_RELEASE = 'b1'

 # Use the following formatting: (major, minor, patch, pre-release)
 VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)

@@ -44,8 +44,10 @@ class AppState(metaclass=Singleton):
         self._world_size = None
         self._model_parallel_size = None
         self._model_parallel_group = None
+        self._is_megatron_initialized = False
         self._data_parallel_size = None
         self._data_parallel_group = None
+        self._megatron_checkpoint_version = None

         self._random_seed = None

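The reason this state lives on `AppState`: it is a singleton, so values recorded during Megatron initialization are visible everywhere without plumbing arguments through. A small hedged sketch (the attribute value is illustrative):

```python
from nemo.utils.app_state import AppState

AppState().model_parallel_size = 2           # e.g. set during environment setup
assert AppState().model_parallel_size == 2   # same instance seen elsewhere
```
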
@@ -31,11 +31,10 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.types import _METRIC

 from nemo.constants import NEMO_ENV_VARNAME_VERSION
-from nemo.utils import app_state, logging, timers
+from nemo.utils import logging, timers
 from nemo.utils.app_state import AppState
 from nemo.utils.exceptions import NeMoBaseException
 from nemo.utils.get_rank import is_global_rank_zero
@@ -72,7 +71,6 @@ class CallbackParams:
     save_top_k: Optional[int] = 3
     save_weights_only: Optional[bool] = False
     mode: Optional[str] = "min"
-    period: Optional[int] = None
     every_n_val_epochs: Optional[int] = 1
     prefix: Optional[str] = None  # If None, exp_manager will attempt to handle the filepath
     postfix: str = ".nemo"
@@ -380,6 +378,7 @@ def check_resume(
         NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
         ValueError: If resume is True, and more than one checkpoint was found.
     """
+
     if not log_dir:
         raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")

@@ -707,58 +706,55 @@ class NeMoModelCheckpoint(ModelCheckpoint):
         for _ in range(models_to_delete):
             model = best_k_models.pop(-1)
             self.best_k_models.pop(model)
-            self._del_model(model)
+            self._del_model_without_trainer(model)
             logging.debug(f"Removed checkpoint: {model}")

         self.kth_best_model_path = best_k_models[-1]
         self.best_model_path = best_k_models[0]
         self.best_model_score = self.best_k_models[self.best_model_path]

     @rank_zero_only
     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         output = super().on_save_checkpoint(trainer, pl_module, checkpoint)

         if not self.always_save_nemo:
             return output

-        # Load the best model and then re-save it
-        app_state = AppState()
-        # since we are creating tarfile artifacts we need to update .nemo path
-        app_state.model_restore_path = os.path.abspath(
-            os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
-        )
-        if self.save_best_model:
-            if not os.path.exists(self.best_model_path):
-                return output
-
-            if self.best_model_path == self.previous_best_path:
-                return output
-
-            self.previous_model_path = self.best_model_path
-            old_state_dict = deepcopy(pl_module.state_dict())
-            checkpoint = torch.load(self.best_model_path, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                checkpoint = checkpoint['state_dict']
-            # get a new instance of the model
-            pl_module.load_state_dict(checkpoint, strict=True)
-            pl_module.save_to(save_path=app_state.model_restore_path)
-            pl_module.load_state_dict(old_state_dict, strict=True)
-        else:
-            pl_module.save_to(save_path=app_state.model_restore_path)
-        return output
+        # Load the best model and then re-save it
+        app_state = AppState()
+        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+            raise ValueError(f'always_save_nemo is not implemented for model parallel models.')
+        # since we are creating tarfile artifacts we need to update .nemo path
+        app_state.model_restore_path = os.path.abspath(
+            os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
+        )
+        if self.save_best_model:
+            if not os.path.exists(self.best_model_path):
+                return output
+
+            if self.best_model_path == self.previous_best_path:
+                return output
+
+            self.previous_model_path = self.best_model_path
+            old_state_dict = deepcopy(pl_module.state_dict())
+            checkpoint = torch.load(self.best_model_path, map_location='cpu')
+            if 'state_dict' in checkpoint:
+                checkpoint = checkpoint['state_dict']
+            # get a new instance of the model
+            pl_module.load_state_dict(checkpoint, strict=True)
+            pl_module.save_to(save_path=app_state.model_restore_path)
+            pl_module.load_state_dict(old_state_dict, strict=True)
+        else:
+            pl_module.save_to(save_path=app_state.model_restore_path)
+        return output

     @rank_zero_only
     def on_train_end(self, trainer, pl_module):
         if trainer.fast_dev_run:
             return None
+        app_state = AppState()
+        if app_state.model_parallel_size is not None:
+            return None

+        # TODO: make this work for model parallel, need to call on data parallel rank 0 and update best_model_path
         # Load the best model and then re-save it
         if self.save_best_model:
             trainer.checkpoint_connector.restore(self.best_model_path)

         pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))

     def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:

@@ -768,18 +764,36 @@ class NeMoModelCheckpoint(ModelCheckpoint):
         app_state = AppState()
         if app_state.model_parallel_size is not None:
             # filepath needs to be updated to include mp_rank
+            # TODO: figure out a good way to update these filepaths
             dirname = os.path.dirname(filepath)
             basename = os.path.basename(filepath)
             filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'

             # each model parallel rank needs to remove its model
             if app_state.data_parallel_rank == 0:
-                super()._del_model(trainer, filepath)
-                logging.info(f"Removed model parallel checkpoint: {filepath}")
+                if self._fs.exists(filepath):
+                    self._fs.rm(filepath)
+                    logging.info(f"Removed model parallel checkpoint: {filepath}")

         else:
             return super()._del_model(trainer, filepath)

+    def _del_model_without_trainer(self, filepath: str) -> None:
+        app_state = AppState()
+        if app_state.model_parallel_size is not None:
+            # filepath needs to be updated to include mp_rank
+            dirname = os.path.dirname(filepath)
+            basename = os.path.basename(filepath)
+            filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
+
+        # each model parallel rank needs to remove its model
+        if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0):
+            try:
+                self._fs.rm(filepath)
+                logging.info(f"Removed checkpoint: {filepath}")
+            except:
+                logging.info(f"Tried to remove checkpoint: {filepath} but failed.")
+
     def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, _METRIC]) -> None:
         """ Overrides PTL method to account for model parallel checkpoints.
         Checks for data parallel rank 0 rather than global rank 0.

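A worked example of the `mp_rank` path injection used above: each model parallel rank keeps its shard under an `mp_rank_XX` subdirectory of the checkpoint directory. The filepath is illustrative:

```python
import os

filepath = '/exp/checkpoints/megatron_gpt--val_loss=1.23.ckpt'
model_parallel_rank = 1

dirname = os.path.dirname(filepath)
basename = os.path.basename(filepath)
sharded = f'{dirname}/mp_rank_{model_parallel_rank:02d}/{basename}'
print(sharded)  # /exp/checkpoints/mp_rank_01/megatron_gpt--val_loss=1.23.ckpt
```
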
@@ -792,11 +806,18 @@ class NeMoModelCheckpoint(ModelCheckpoint):
         filepath = self._format_checkpoint_name(self.CHECKPOINT_NAME_LAST, monitor_candidates)
         filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")

-        self._save_model(trainer, filepath)
+        trainer.save_checkpoint(filepath)
+
+        # TODO: figure out where self.last_model_path is being set
+        if self.last_model_path is not None:
+            if 'mp_rank' in self.last_model_path:
+                last_model_path = Path(self.last_model_path)
+                last_model_path = last_model_path.parent.parent.joinpath(last_model_path.name)
+                self.last_model_path = str(last_model_path)

         # for model parallel we need to delete models for each model parallel rank
         if self.last_model_path and self.last_model_path != filepath and app_state.data_parallel_rank == 0:
-            self._del_model(self.last_model_path)
+            self._del_model(trainer, self.last_model_path)

         self.last_model_path = filepath

@@ -813,7 +834,8 @@ class NeMoModelCheckpoint(ModelCheckpoint):
             return

         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
-        self._save_model(trainer, filepath)
+
+        trainer.save_checkpoint(filepath)

         if (
             self.save_top_k is None
@@ -821,7 +843,7 @@ class NeMoModelCheckpoint(ModelCheckpoint):
             and self.best_model_path != filepath
             and app_state.data_parallel_rank == 0
         ):
-            self._del_model(self.best_model_path)
+            self._del_model(trainer, self.best_model_path)

             self.best_model_path = filepath
         else:

|
|||
f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
|
||||
)
|
||||
|
||||
if params.period is not None:
|
||||
logging.warning(
|
||||
"The use of `period` in the checkpoint callback is deprecrated, please use `every_n_val_epochs` instead. "
|
||||
"Overwriting `every_n_val_epochs` with `period`."
|
||||
)
|
||||
params.every_n_val_epochs = params.period
|
||||
|
||||
# NOTE: checkpoint_callback should be called last
|
||||
checkpoint_callback = NeMoModelCheckpoint(n_resume=resume, **params)
|
||||
checkpoint_callback.last_model_path = trainer.checkpoint_connector.resume_checkpoint_path or ""
|
||||
trainer.callbacks.append(checkpoint_callback)
|
||||
|
|
Some files were not shown because too many files have changed in this diff.