DeepLearningExamples/PyTorch/SpeechRecognition/QuartzNet/train.py
2021-09-14 06:03:36 -07:00

559 lines
24 KiB
Python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import os
import random
import time
try:
import nvidia_dlprof_pytorch_nvtx as pyprof
except:
import pyprof
import torch
import amp_C
import numpy as np
import torch.cuda.profiler as profiler
import torch.distributed as dist
from apex.optimizers import FusedLAMB, FusedNovoGrad
from contextlib import suppress as empty_context
from common import helpers
from common.dali.data_loader import DaliDataLoader
from common.dataset import AudioDataset, get_data_loader
from common.features import BaseFeatures, FilterbankFeatures
from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
process_evaluation_epoch)
from common.optimizers import AdamW, lr_policy, Novograd
from common.tb_dllogger import flush_log, init_log, log
from common.utils import BenchmarkStats
from quartznet import config
from quartznet.model import CTCLossNM, GreedyCTCDecoder, QuartzNet
def parse_args():
    """Build the QuartzNet training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the training, optimization and
        feature/checkpointing options defined below.
    """
    parser = argparse.ArgumentParser(description='QuartzNet')

    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', default=400, type=int,
                          help='Number of epochs for the entire training; influences the lr schedule')
    training.add_argument("--warmup_epochs", default=0, type=int,
                          help='Initial epochs of increasing learning rate')
    training.add_argument("--hold_epochs", default=0, type=int,
                          help='Constant max learning rate epochs after warmup')
    training.add_argument('--epochs_this_job', default=0, type=int,
                          help=('Run for a number of epochs with no effect on the lr schedule. '
                                'Useful for re-starting the training.'))
    # A ``store_true`` flag with ``default=True`` can never be switched off, so
    # ``--no_cudnn_benchmark`` writes ``False`` into the same destination.
    # ``--cudnn_benchmark`` is kept (effectively a no-op) for backward compatibility.
    training.add_argument('--cudnn_benchmark', action='store_true', default=True,
                          help='Enable cudnn benchmark')
    training.add_argument('--no_cudnn_benchmark', dest='cudnn_benchmark',
                          action='store_false', default=True,
                          help='Disable cudnn benchmark')
    training.add_argument('--amp', '--fp16', action='store_true', default=False,
                          help='Use pytorch native mixed precision training')
    training.add_argument('--seed', default=1, type=int, help='Random seed')
    # argparse applies ``type=int`` to string defaults, so the env value is converted.
    training.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int,
                          help='GPU id used for distributed training')
    training.add_argument('--pre_allocate_range', default=None, type=int, nargs=2,
                          help='Warmup with batches of length [min, max] before training')
    training.add_argument('--pyprof', action='store_true', help='Enable pyprof profiling')

    optim = parser.add_argument_group('optimization setup')
    optim.add_argument('--gpu_batch_size', default=32, type=int,
                       help='Batch size for a single forward/backward pass. '
                            'The Effective batch size is gpu_batch_size * grad_accumulation.')
    optim.add_argument('--lr', default=1e-3, type=float,
                       help='Peak learning rate')
    optim.add_argument("--min_lr", default=1e-5, type=float,
                       help='minimum learning rate')
    optim.add_argument("--lr_policy", default='exponential', type=str,
                       choices=['exponential', 'legacy'], help='lr scheduler')
    optim.add_argument("--lr_exp_gamma", default=0.99, type=float,
                       help='gamma factor for exponential lr scheduler')
    optim.add_argument('--weight_decay', default=1e-3, type=float,
                       help='Weight decay for the optimizer')
    optim.add_argument('--grad_accumulation', '--update-freq', default=1, type=int,
                       help='Number of accumulation steps')
    optim.add_argument('--optimizer', default='novograd', type=str,
                       choices=['novograd', 'adamw', 'lamb98', 'fused_novograd'],
                       help='Optimization algorithm')
    optim.add_argument('--ema', type=float, default=0.0,
                       help='Discount factor for exp averaging of model weights')
    optim.add_argument('--multi_tensor_ema', action='store_true',
                       help='Use multi_tensor_apply for EMA')

    io = parser.add_argument_group('feature and checkpointing setup')
    io.add_argument('--dali_device', type=str, choices=['none', 'cpu', 'gpu'],
                    default='gpu', help='Use DALI pipeline for fast data processing')
    io.add_argument('--resume', action='store_true',
                    help='Try to resume from last saved checkpoint.')
    io.add_argument('--ckpt', default=None, type=str,
                    help='Path to a checkpoint for resuming training')
    io.add_argument('--save_frequency', default=10, type=int,
                    help='Checkpoint saving frequency in epochs')
    io.add_argument('--keep_milestones', default=[100, 200, 300], type=int, nargs='+',
                    help='Milestone checkpoints to keep from removing')
    io.add_argument('--save_best_from', default=380, type=int,
                    help='Epoch on which to begin tracking best checkpoint (dev WER)')
    io.add_argument('--eval_frequency', default=200, type=int,
                    help='Number of steps between evaluations on dev set')
    io.add_argument('--log_frequency', default=25, type=int,
                    help='Number of steps between printing training stats')
    io.add_argument('--prediction_frequency', default=100, type=int,
                    help='Number of steps between printing sample decodings')
    io.add_argument('--model_config', type=str, required=True,
                    help='Path of the model configuration file')
    io.add_argument('--train_manifests', type=str, required=True, nargs='+',
                    help='Paths of the training dataset manifest file')
    io.add_argument('--val_manifests', type=str, required=True, nargs='+',
                    help='Paths of the evaluation datasets manifest files')
    io.add_argument('--dataset_dir', required=True, type=str,
                    help='Root dir of dataset')
    io.add_argument('--output_dir', type=str, required=True,
                    help='Directory for logs and checkpoints')
    io.add_argument('--log_file', type=str, default=None,
                    help='Path to save the training logfile.')
    io.add_argument('--benchmark_epochs_num', type=int, default=1,
                    help='Number of epochs accounted in final average throughput.')
    io.add_argument('--override_config', type=str, action='append',
                    help='Overrides arbitrary config value.'
                         ' Syntax: `--override_config nested.config.key=val`.')
    return parser.parse_args()
def reduce_tensor(tensor, num_gpus):
    """Return the mean of ``tensor`` over all ranks of the default process group.

    The input is cloned before the collective call, so the caller's tensor is
    left untouched.

    Args:
        tensor: tensor to average (must live on a device the backend supports).
        num_gpus: divisor applied to the summed result, normally the world size.

    Returns:
        A new tensor equal to ``sum_over_ranks(tensor) / num_gpus`` (true division).
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    averaged = summed.true_divide(num_gpus)
    return averaged
def init_multi_tensor_ema(model, ema_model):
    """Prepare the argument triple consumed by ``apply_multi_tensor_ema``.

    Args:
        model: the module being trained (may be a DDP wrapper; its state_dict
            values are paired positionally with those of ``ema_model``).
        ema_model: the module holding the exponential moving average.

    Returns:
        ``(model_weights, ema_model_weights, ema_overflow_buf)`` where the first
        two are lists of the state_dict tensors (in state_dict order) and the
        last is a one-element CUDA int tensor used as the kernel's overflow flag.
    """
    live_tensors = [t for t in model.state_dict().values()]
    averaged_tensors = [t for t in ema_model.state_dict().values()]
    overflow_flag = torch.cuda.IntTensor([0])
    return live_tensors, averaged_tensors, overflow_flag
def apply_multi_tensor_ema(decay, model_weights, ema_model_weights, overflow_buf):
    """Update the EMA weights in place with a single fused apex call.

    Uses apex's ``amp_C.multi_tensor_axpby`` with
    ``x = ema_model_weights``, ``y = model_weights`` and the output list set
    back to ``ema_model_weights``, i.e. ``ema = decay * ema + (1 - decay) * model``
    (assuming axpby computes ``out = a*x + b*y``; see apex sources).

    Args:
        decay: EMA discount factor applied to the current averaged weights.
        model_weights: tensors of the live model (from ``init_multi_tensor_ema``).
        ema_model_weights: tensors of the EMA model; overwritten with the result.
        overflow_buf: int flag tensor passed through to the kernel.
    """
    chunk_size = 65536
    # Output list aliases the first input list, so the update happens in place.
    tensor_lists = [ema_model_weights, model_weights, ema_model_weights]
    # Last argument (-1) is the kernel's arg_to_check; presumably "check both
    # inputs for non-finite values" — confirm against apex if it matters.
    amp_C.multi_tensor_axpby(chunk_size, overflow_buf, tensor_lists,
                             decay, 1 - decay, -1)
def apply_ema(model, ema_model, decay):
    """Blend ``model``'s state into ``ema_model`` in place.

    Every state_dict entry of ``ema_model`` is overwritten with
    ``decay * ema + (1 - decay) * model``. A DDP-style wrapper is unwrapped
    through its ``module`` attribute so the state_dict keys line up.

    Args:
        model: the live model, optionally wrapped (``model.module``).
        ema_model: the averaged copy; its tensors are updated via ``copy_``.
        decay: discount factor; a falsy value (e.g. ``0``) makes this a no-op.
    """
    if not decay:
        return

    source = model.module if hasattr(model, 'module') else model
    current = source.state_dict()
    keep, blend = decay, 1 - decay
    for name, averaged in ema_model.state_dict().items():
        averaged.copy_(keep * averaged + blend * current[name])
@torch.no_grad()
def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
             ema_model, ctc_loss, greedy_decoder, use_amp, use_dali=False):
    """Run validation for the model and, if present, its EMA copy.

    Each non-None model is switched to eval mode, run over ``val_loader``, and
    its loss/WER are logged under ``'dev'`` (model) or ``'dev_ema'`` (EMA
    model). The model is put back into train mode afterwards.

    Args:
        epoch, step: training position, forwarded to ``log``.
        val_loader: validation batches; DALI batches are
            ``(feat, feat_lens, txt, txt_lens)`` already on GPU, otherwise
            ``(audio, audio_lens, txt, txt_lens)`` moved to CUDA here.
        val_feat_proc: feature module (optional padder in the DALI case).
        labels: symbol table used to decode predictions and transcripts.
        model, ema_model: networks to evaluate; ``ema_model`` may be None.
        ctc_loss, greedy_decoder: loss and decoding callables.
        use_amp: enable autocast for the forward pass.
        use_dali: whether ``val_loader`` is a DALI loader.

    Returns:
        WER (fraction) of the last evaluated network — the EMA model when one
        is given, otherwise ``model``.
    """
    for net, subset in ((model, 'dev'), (ema_model, 'dev_ema')):
        if net is None:
            continue

        net.eval()
        tic = time.time()
        losses, preds, txts = [], [], []

        for batch in val_loader:
            if not use_dali:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = val_feat_proc(audio, audio_lens)
            else:
                # DALI already delivers the features on the GPU.
                feat, feat_lens, txt, txt_lens = batch
                if val_feat_proc is not None:
                    feat, feat_lens = val_feat_proc(feat, feat_lens)

            with torch.cuda.amp.autocast(enabled=use_amp):
                log_probs, enc_lens = net(feat, feat_lens)
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)

            pred = greedy_decoder(log_probs)

            losses += helpers.gather_losses([loss])
            preds += helpers.gather_predictions([pred], labels)
            txts += helpers.gather_transcripts([txt], [txt_lens], labels)

        wer, loss = process_evaluation_epoch(
            {'losses': losses, 'preds': preds, 'txts': txts})
        elapsed = time.time() - tic
        log((epoch,), step, subset, {'loss': loss, 'wer': 100.0 * wer,
                                     'took': elapsed})
        net.train()
    return wer
def _dali_padder(features_kw):
    """Return a ``BaseFeatures`` padder when ``features_kw`` asks for max-duration padding.

    DALI cannot pad to ``max_duration`` itself. When padding is requested,
    this builds a ``BaseFeatures`` module that does it after the DALI loader.
    It also resets ``features_kw['pad_to_max_duration']`` to ``False`` in place,
    so the DALI pipeline does not try to pad.

    Returns:
        The padder module, or ``None`` when no padding is requested.
    """
    if not features_kw['pad_to_max_duration']:
        return None
    padder = BaseFeatures(
        pad_align=features_kw['pad_align'],
        pad_to_max_duration=True,
        max_duration=features_kw['max_duration'],
        sample_rate=features_kw['sample_rate'],
        window_size=features_kw['window_size'],
        window_stride=features_kw['window_stride'])
    features_kw['pad_to_max_duration'] = False
    return padder


def _build_optimizer(args, model):
    """Instantiate the optimizer selected by ``--optimizer`` for ``model``.

    Raises:
        ValueError: if ``args.optimizer`` is not a known optimizer name.
    """
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        return Novograd(model.parameters(), **kw)
    if args.optimizer == "adamw":
        return AdamW(model.parameters(), **kw)
    if args.optimizer == 'lamb98':
        return FusedLAMB(model.parameters(), betas=(0.9, 0.98), eps=1e-9,
                         **kw)
    if args.optimizer == 'fused_novograd':
        return FusedNovoGrad(model.parameters(), betas=(0.95, 0),
                             bias_correction=False, reg_inside_moment=True,
                             grad_averaging=False, **kw)
    raise ValueError(f'Invalid optimizer "{args.optimizer}"')


def main():
    """Train QuartzNet end to end.

    The steps are, in order:

    * parse the CLI;
    * optionally initialise NCCL distributed training;
    * build the data loaders (DALI or native), the model, the optimizer, the
      GradScaler and an optional EMA copy of the model;
    * optionally resume from a checkpoint;
    * run the epoch loop with gradient accumulation, periodic logging,
      evaluation and checkpointing;
    * finish with a final evaluation and checkpoint.
    """
    args = parse_args()

    assert(torch.cuda.is_available())
    # Sample decodings are only printed inside logging steps.
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Per-rank seeds so that processes do not draw identical random streams.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        train_feat_proc = _dali_padder(train_features_kw)
        val_feat_proc = _dali_padder(val_features_kw)

        train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                      dataset_path=args.dataset_dir,
                                      config_data=train_dataset_kw,
                                      config_features=train_features_kw,
                                      json_names=args.train_manifests,
                                      batch_size=batch_size,
                                      grad_accumulation_steps=args.grad_accumulation,
                                      pipeline_type="train",
                                      device_type=args.dali_device,
                                      symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset = AudioDataset(args.dataset_dir,
                                     args.train_manifests,
                                     symbols,
                                     **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset = AudioDataset(args.dataset_dir,
                                   args.val_manifests,
                                   symbols,
                                   **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation
    if steps_per_epoch == 0:
        # Later code divides by and takes modulo of steps_per_epoch; fail clearly.
        raise ValueError(
            f'Training loader yields {len(train_loader)} batches per epoch, '
            f'fewer than --grad_accumulation={args.grad_accumulation}')

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    optimizer = _build_optimizer(args, model)

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    def adjust_lr(step, epoch, optimizer):
        """Set the optimizer's learning rate according to the configured policy."""
        return lr_policy(
            step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
            warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
            num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
            exp_gamma=args.lr_exp_gamma)

    # The EMA copy is taken before DDP wrapping, so it is a plain module.
    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    # Keep a single emit_nvtx instance so the same object is entered and exited.
    nvtx_ctx = None
    if args.pyprof:
        nvtx_ctx = torch.autograd.profiler.emit_nvtx()
        nvtx_ctx.__enter__()
        profiler.start()

    # training loop
    model.train()
    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    # pre-allocate: run throwaway forward/backward passes over a range of
    # sequence lengths before training starts
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols)-1, size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()

        for batch in train_loader:

            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Skip DDP gradient all-reduce on all but the last accumulated batch
            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                # The reduced loss is identical on every rank, so all ranks skip together.
                if torch.isnan(reduced_loss).any():
                    print_once(f'WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': step_loss,
                         'wer': 100.0 * wer,
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        # Update first so the best checkpoint records its own WER
                        # rather than the previous best.
                        best_wer = wer
                        checkpointer.save(model, ema_model, optimizer, scaler,
                                          epoch, step, best_wer, is_best=True)

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time,
                                          'loss': epoch_loss})
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        nvtx_ctx.__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

    checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                      best_wer)
    flush_log()
if __name__ == "__main__":
    # Entry point: train only when run as a script, never on import.
    main()