# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import os
import random
import time

try:
    import nvidia_dlprof_pytorch_nvtx as pyprof
except ImportError:
    import pyprof

import torch
import amp_C
import numpy as np
import torch.cuda.profiler as profiler
import torch.distributed as dist
from apex.optimizers import FusedLAMB, FusedNovoGrad
from contextlib import suppress as empty_context

from common import helpers
from common.dali.data_loader import DaliDataLoader
from common.dataset import AudioDataset, get_data_loader
from common.features import BaseFeatures, FilterbankFeatures
from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
                            process_evaluation_epoch)
from common.optimizers import AdamW, lr_policy, Novograd
from common.tb_dllogger import flush_log, init_log, log
from common.utils import BenchmarkStats
from quartznet import config
from quartznet.model import CTCLossNM, GreedyCTCDecoder, QuartzNet


def parse_args():
    parser = argparse.ArgumentParser(description='QuartzNet')

    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', default=400, type=int,
                          help='Number of epochs for the entire training; '
                               'influences the lr schedule')
    training.add_argument('--warmup_epochs', default=0, type=int,
                          help='Initial epochs of increasing learning rate')
    training.add_argument('--hold_epochs', default=0, type=int,
                          help='Constant max learning rate epochs after warmup')
    training.add_argument('--epochs_this_job', default=0, type=int,
                          help='Run for a number of epochs with no effect on the lr schedule. '
                               'Useful for re-starting the training.')
    training.add_argument('--cudnn_benchmark', action='store_true', default=True,
                          help='Enable cudnn benchmark')
    training.add_argument('--amp', '--fp16', action='store_true', default=False,
                          help='Use PyTorch native mixed precision training')
    training.add_argument('--seed', default=1, type=int, help='Random seed')
    training.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0),
                          type=int, help='GPU id used for distributed training')
    training.add_argument('--pre_allocate_range', default=None, type=int, nargs=2,
                          help='Warmup with batches of length [min, max] before training')
    training.add_argument('--pyprof', action='store_true',
                          help='Enable pyprof profiling')

    optim = parser.add_argument_group('optimization setup')
    optim.add_argument('--gpu_batch_size', default=32, type=int,
                       help='Batch size for a single forward/backward pass. '
                            'The effective batch size is gpu_batch_size * grad_accumulation.')
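    # Note: with gradient accumulation, each optimizer step consumes
    # gpu_batch_size * grad_accumulation samples per GPU; under multi-GPU
    # (DistributedDataParallel) training this scales further by WORLD_SIZE.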
    optim.add_argument('--lr', default=1e-3, type=float,
                       help='Peak learning rate')
    optim.add_argument('--min_lr', default=1e-5, type=float,
                       help='Minimum learning rate')
    optim.add_argument('--lr_policy', default='exponential', type=str,
                       choices=['exponential', 'legacy'], help='lr scheduler')
    optim.add_argument('--lr_exp_gamma', default=0.99, type=float,
                       help='gamma factor for exponential lr scheduler')
    optim.add_argument('--weight_decay', default=1e-3, type=float,
                       help='Weight decay for the optimizer')
    optim.add_argument('--grad_accumulation', '--update-freq', default=1, type=int,
                       help='Number of gradient accumulation steps')
    optim.add_argument('--optimizer', default='novograd', type=str,
                       choices=['novograd', 'adamw', 'lamb98', 'fused_novograd'],
                       help='Optimization algorithm')
    optim.add_argument('--ema', type=float, default=0.0,
                       help='Discount factor for exp averaging of model weights')
    optim.add_argument('--multi_tensor_ema', action='store_true',
                       help='Use multi_tensor_apply for EMA')

    io = parser.add_argument_group('feature and checkpointing setup')
    io.add_argument('--dali_device', type=str, choices=['none', 'cpu', 'gpu'],
                    default='gpu', help='Use DALI pipeline for fast data processing')
    io.add_argument('--resume', action='store_true',
                    help='Try to resume from last saved checkpoint.')
    io.add_argument('--ckpt', default=None, type=str,
                    help='Path to a checkpoint for resuming training')
    io.add_argument('--save_frequency', default=10, type=int,
                    help='Checkpoint saving frequency in epochs')
    io.add_argument('--keep_milestones', default=[100, 200, 300], type=int, nargs='+',
                    help='Milestone checkpoints that are never removed')
    io.add_argument('--save_best_from', default=380, type=int,
                    help='Epoch on which to begin tracking best checkpoint (dev WER)')
    io.add_argument('--eval_frequency', default=200, type=int,
                    help='Number of steps between evaluations on dev set')
    io.add_argument('--log_frequency', default=25, type=int,
                    help='Number of steps between printing training stats')
    io.add_argument('--prediction_frequency', default=100, type=int,
                    help='Number of steps between printing sample decodings')
    io.add_argument('--model_config', type=str, required=True,
                    help='Path of the model configuration file')
    io.add_argument('--train_manifests', type=str, required=True, nargs='+',
                    help='Paths of the training dataset manifest files')
    io.add_argument('--val_manifests', type=str, required=True, nargs='+',
                    help='Paths of the evaluation dataset manifest files')
    io.add_argument('--dataset_dir', required=True, type=str,
                    help='Root dir of dataset')
    io.add_argument('--output_dir', type=str, required=True,
                    help='Directory for logs and checkpoints')
    io.add_argument('--log_file', type=str, default=None,
                    help='Path to save the training logfile.')
    io.add_argument('--benchmark_epochs_num', type=int, default=1,
                    help='Number of epochs accounted in final average throughput.')
    io.add_argument('--override_config', type=str, action='append',
                    help='Overrides arbitrary config value. '
                         'Syntax: `--override_config nested.config.key=val`.')
    return parser.parse_args()
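# Helper used for logging under DDP: averages a tensor across ranks via
# all-reduce(SUM) followed by division by the GPU count, so every worker
# reports the same loss value.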
def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt.true_divide(num_gpus)


def init_multi_tensor_ema(model, ema_model):
    model_weights = list(model.state_dict().values())
    ema_model_weights = list(ema_model.state_dict().values())
    ema_overflow_buf = torch.cuda.IntTensor([0])
    return model_weights, ema_model_weights, ema_overflow_buf


def apply_multi_tensor_ema(decay, model_weights, ema_model_weights, overflow_buf):
    # Fused update over all tensors at once: ema = decay * ema + (1 - decay) * model
    amp_C.multi_tensor_axpby(
        65536, overflow_buf,
        [ema_model_weights, model_weights, ema_model_weights],
        decay, 1 - decay, -1)


def apply_ema(model, ema_model, decay):
    if not decay:
        return

    sd = getattr(model, 'module', model).state_dict()
    for k, v in ema_model.state_dict().items():
        v.copy_(decay * v + (1 - decay) * sd[k])


@torch.no_grad()
def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
             ema_model, ctc_loss, greedy_decoder, use_amp, use_dali=False):

    for model, subset in [(model, 'dev'), (ema_model, 'dev_ema')]:
        if model is None:
            continue

        model.eval()
        start_time = time.time()
        agg = {'losses': [], 'preds': [], 'txts': []}

        for batch in val_loader:
            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if val_feat_proc is not None:
                    feat, feat_lens = val_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = val_feat_proc(audio, audio_lens)

            with torch.cuda.amp.autocast(enabled=use_amp):
                log_probs, enc_lens = model(feat, feat_lens)
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                pred = greedy_decoder(log_probs)

            agg['losses'] += helpers.gather_losses([loss])
            agg['preds'] += helpers.gather_predictions([pred], labels)
            agg['txts'] += helpers.gather_transcripts([txt], [txt_lens], labels)

        wer, loss = process_evaluation_epoch(agg)
        log((epoch,), step, subset, {'loss': loss, 'wer': 100.0 * wer,
                                     'took': time.time() - start_time})
        model.train()
    return wer
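# main() wires everything together: config/argument parsing, optional
# distributed (DDP) setup, DALI or native PyTorch data pipelines, the
# QuartzNet model with CTC loss and greedy decoding, the optimizer and lr
# policy, optional EMA weights, and the train/eval loop.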
def main():
    args = parse_args()

    assert torch.cuda.is_available()
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                      dataset_path=args.dataset_dir,
                                      config_data=train_dataset_kw,
                                      config_features=train_features_kw,
                                      json_names=args.train_manifests,
                                      batch_size=batch_size,
                                      grad_accumulation_steps=args.grad_accumulation,
                                      pipeline_type="train",
                                      device_type=args.dali_device,
                                      symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset = AudioDataset(args.dataset_dir,
                                     args.train_manifests,
                                     symbols,
                                     **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset = AudioDataset(args.dataset_dir,
                                   args.val_manifests,
                                   symbols,
                                   **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    elif args.optimizer == 'lamb98':
        optimizer = FusedLAMB(model.parameters(), betas=(0.9, 0.98), eps=1e-9,
                              **kw)
    elif args.optimizer == 'fused_novograd':
        optimizer = FusedNovoGrad(model.parameters(), betas=(0.95, 0),
                                  bias_correction=False, reg_inside_moment=True,
                                  grad_averaging=False, **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
        num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None
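    # DistributedDataParallel all-reduces gradients during backward(); the
    # training loop below suspends that sync with model.no_sync() while
    # accumulating, so only the last micro-batch of a step communicates.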
    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()

    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1, size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)
                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once('WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()
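            # An optimizer step is taken only on every grad_accumulation-th
            # micro-batch: GradScaler unscales gradients in scaler.step() and
            # adapts the loss scale in scaler.update(); afterwards the lr
            # schedule and (optionally) the EMA weights are updated.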
            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch,
                         steps_per_epoch),
                        step, 'train',
                        {'loss': step_loss,
                         'wer': 100.0 * wer,
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model, ema_model, optimizer, scaler,
                                          epoch, step, best_wer, is_best=True)
                        best_wer = wer
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # The DALI iterator needs to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time,
                                          'loss': epoch_loss})
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()


if __name__ == "__main__":
    main()
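# Example multi-GPU launch (hypothetical paths; adjust to your setup):
#   python -m torch.distributed.launch --nproc_per_node=8 train.py \
#       --model_config <model_config.yaml> --dataset_dir <dataset_root> \
#       --train_manifests <train.json> --val_manifests <dev.json> \
#       --output_dir <results_dir> --amp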