668 lines
27 KiB
Python
668 lines
27 KiB
Python
# *****************************************************************************
|
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
# * Redistributions of source code must retain the above copyright
|
|
# notice, this list of conditions and the following disclaimer.
|
|
# * Redistributions in binary form must reproduce the above copyright
|
|
# notice, this list of conditions and the following disclaimer in the
|
|
# documentation and/or other materials provided with the distribution.
|
|
# * Neither the name of the NVIDIA CORPORATION nor the
|
|
# names of its contributors may be used to endorse or promote products
|
|
# derived from this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
|
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
|
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
# *****************************************************************************
|
|
|
|
import argparse
|
|
import copy
|
|
import glob
|
|
import os
|
|
import re
|
|
import time
|
|
import warnings
|
|
from collections import defaultdict, OrderedDict
|
|
|
|
try:
|
|
import nvidia_dlprof_pytorch_nvtx as pyprof
|
|
except ModuleNotFoundError:
|
|
try:
|
|
import pyprof
|
|
except ModuleNotFoundError:
|
|
warnings.warn('PyProf is unavailable')
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.cuda.profiler as profiler
|
|
import torch.distributed as dist
|
|
import amp_C
|
|
from apex.optimizers import FusedAdam, FusedLAMB
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
import common.tb_dllogger as logger
|
|
import models
|
|
from common.text import cmudict
|
|
from common.utils import prepare_tmp
|
|
from fastpitch.attn_loss_function import AttentionBinarizationLoss
|
|
from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
|
|
from fastpitch.loss_function import FastPitchLoss
|
|
|
|
|
|
def parse_args(parser):
    """Register all FastPitch training options on `parser` and return it.

    Options are grouped into: general I/O, training setup, optimization,
    dataset parameters, conditioning data, audio parameters, and
    distributed setup.

    Fixes vs. previous revision: 'Learing'/'trainig' help-text typos, a
    copy-pasted help string on --kl-loss-weight, and the local group
    variable `dist` which shadowed the module-level `torch.distributed`
    import.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')
    parser.add_argument('--pyprof', action='store_true',
                        help='Enable pyprof profiling')

    train = parser.add_argument_group('training setup')
    train.add_argument('--epochs', type=int, required=True,
                       help='Number of total epochs to run')
    train.add_argument('--epochs-per-checkpoint', type=int, default=50,
                       help='Number of epochs per checkpoint')
    train.add_argument('--checkpoint-path', type=str, default=None,
                       help='Checkpoint path to resume training')
    train.add_argument('--resume', action='store_true',
                       help='Resume training from the last checkpoint')
    train.add_argument('--seed', type=int, default=1234,
                       help='Seed for PyTorch random number generators')
    train.add_argument('--amp', action='store_true',
                       help='Enable AMP')
    train.add_argument('--cuda', action='store_true',
                       help='Run on GPU using CUDA')
    train.add_argument('--cudnn-benchmark', action='store_true',
                       help='Enable cudnn benchmark mode')
    train.add_argument('--ema-decay', type=float, default=0,
                       help='Discounting factor for training weights EMA')
    train.add_argument('--grad-accumulation', type=int, default=1,
                       help='Training steps to accumulate gradients for')
    train.add_argument('--kl-loss-start-epoch', type=int, default=250,
                       help='Start adding the hard attention loss term')
    train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
                       help='Gradually increase the hard attention loss term')
    train.add_argument('--kl-loss-weight', type=float, default=1.0,
                       help='Weight of the hard attention loss term')

    opt = parser.add_argument_group('optimization setup')
    opt.add_argument('--optimizer', type=str, default='lamb',
                     help='Optimization algorithm')
    opt.add_argument('-lr', '--learning-rate', type=float, required=True,
                     help='Learning rate')
    opt.add_argument('--weight-decay', default=1e-6, type=float,
                     help='Weight decay')
    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                     help='Clip threshold for gradients')
    opt.add_argument('-bs', '--batch-size', type=int, required=True,
                     help='Batch size per GPU')
    opt.add_argument('--warmup-steps', type=int, default=1000,
                     help='Number of steps for lr warmup')
    opt.add_argument('--dur-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale duration predictor loss')
    opt.add_argument('--pitch-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale pitch predictor loss')
    opt.add_argument('--attn-loss-scale', type=float,
                     default=1.0, help='Rescale alignment loss')

    data = parser.add_argument_group('dataset parameters')
    data.add_argument('--training-files', type=str, nargs='*', required=True,
                      help='Paths to training filelists.')
    data.add_argument('--validation-files', type=str, nargs='*',
                      required=True, help='Paths to validation filelists')
    data.add_argument('--text-cleaners', nargs='*',
                      default=['english_cleaners'], type=str,
                      help='Type of text cleaners for input text')
    data.add_argument('--symbol-set', type=str, default='english_basic',
                      help='Define symbol set for input text')
    data.add_argument('--p-arpabet', type=float, default=0.0,
                      help='Probability of using arpabets instead of graphemes '
                           'for each word; set 0 for pure grapheme training')
    data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
                      help='Path to the list of heteronyms')
    data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
                      help='Path to the pronouncing dictionary')
    data.add_argument('--prepend-space-to-text', action='store_true',
                      help='Capture leading silence with a space token')
    data.add_argument('--append-space-to-text', action='store_true',
                      help='Capture trailing silence with a space token')

    cond = parser.add_argument_group('data for conditioning')
    cond.add_argument('--n-speakers', type=int, default=1,
                      help='Number of speakers in the dataset. '
                           'n_speakers > 1 enables speaker embeddings')
    cond.add_argument('--load-pitch-from-disk', action='store_true',
                      help='Use pitch cached on disk with prepare_dataset.py')
    cond.add_argument('--pitch-online-method', default='pyin',
                      choices=['pyin'],
                      help='Calculate pitch on the fly during training')
    cond.add_argument('--pitch-online-dir', type=str, default=None,
                      help='A directory for storing pitch calculated on-line')
    cond.add_argument('--pitch-mean', type=float, default=214.72203,
                      help='Normalization value for pitch')
    cond.add_argument('--pitch-std', type=float, default=65.72038,
                      help='Normalization value for pitch')
    cond.add_argument('--load-mel-from-disk', action='store_true',
                      help='Use mel-spectrograms cache on the disk')  # XXX

    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audiowave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    # Renamed from `dist` so the group variable does not shadow the
    # module-level `torch.distributed as dist` import.
    dist_grp = parser.add_argument_group('distributed setup')
    # NOTE: argparse re-applies `type` to string defaults, so a value taken
    # from the environment is converted to int just like a CLI value.
    dist_grp.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
                          help='Rank of the process for multiproc; do not set manually')
    dist_grp.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
                          help='Number of processes for multiproc; do not set manually')
    return parser
|
|
|
|
|
|
def reduce_tensor(tensor, num_gpus):
    """Sum `tensor` across all distributed ranks and divide by `num_gpus`.

    Works on a clone, so the caller's tensor is left untouched.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed.true_divide(num_gpus)
|
|
|
|
|
|
def init_distributed(args, world_size, rank):
    """Set the CUDA device for this process and join the process group.

    Uses the 'nccl' backend when running on GPU (`args.cuda`), 'gloo'
    otherwise; rendezvous happens via environment variables ('env://').
    `world_size` is accepted for call-site symmetry but not read here.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed training")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
                            init_method='env://')
    print("Done initializing distributed training")
|
|
|
|
|
|
def last_checkpoint(output):
    """Return the path of the newest usable checkpoint in `output`.

    Checkpoints are matched by the 'FastPitch_checkpoint_<N>.pt' naming
    scheme and ordered by epoch number N.  If the newest file fails to
    load, falls back to the second newest; returns None when no candidate
    exists.
    """

    def corrupted(fpath):
        # Probe-load on CPU only; any load failure marks the file unusable.
        try:
            torch.load(fpath, map_location='cpu')
            return False
        # Fix: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit during the probe load.
        except Exception:
            warnings.warn(f'Cannot load {fpath}')
            return True

    # Fix: raw string for the regex (avoids the invalid '\d' escape
    # warning) and escaped '.' so it cannot match an arbitrary character.
    saved = sorted(
        glob.glob(f'{output}/FastPitch_checkpoint_*.pt'),
        key=lambda f: int(re.search(r'_(\d+)\.pt', f).group(1)))

    if len(saved) >= 1 and not corrupted(saved[-1]):
        return saved[-1]
    elif len(saved) >= 2:
        # Newest is corrupted; fall back to the previous one (assumed good).
        return saved[-2]
    else:
        return None
|
|
|
|
|
|
def maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                          total_iter, config, final_checkpoint=False):
    """Write a training checkpoint when the schedule calls for one.

    Only rank 0 saves.  A checkpoint is written either at every
    `args.epochs_per_checkpoint`-th epoch or once the final epoch is
    reached; otherwise this is a no-op.  The AMP scaler state is included
    only when `args.amp` is set.  (`final_checkpoint` is accepted for
    interface compatibility but not consulted.)
    """
    if args.local_rank != 0:
        return

    at_interval = (args.epochs_per_checkpoint > 0
                   and epoch % args.epochs_per_checkpoint == 0)

    # Skip unless this is an interval epoch or training has finished.
    if not (at_interval or epoch >= args.epochs):
        return

    fpath = os.path.join(args.output, f"FastPitch_checkpoint_{epoch}.pt")
    print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
    state = {
        'epoch': epoch,
        'iteration': total_iter,
        'config': config,
        'state_dict': model.state_dict(),
        'ema_state_dict': (None if ema_model is None
                           else ema_model.state_dict()),
        'optimizer': optimizer.state_dict(),
    }
    if args.amp:
        state['scaler'] = scaler.state_dict()
    torch.save(state, fpath)
|
|
|
|
|
|
def load_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                    total_iter, config, filepath):
    """Restore training state from the checkpoint at `filepath`.

    `epoch` and `total_iter` are single-element lists used as out-params:
    epoch[0] is set to the saved epoch + 1 (the next epoch to run) and
    total_iter[0] to the saved iteration count.  Model weights have any
    DDP 'module.' prefix stripped before loading; optimizer, AMP scaler
    (when `args.amp`) and EMA weights (when `ema_model` is given) are
    restored as well.
    """
    if args.local_rank == 0:
        print(f'Loading model and optimizer state from {filepath}')
    ckpt = torch.load(filepath, map_location='cpu')
    epoch[0] = ckpt['epoch'] + 1
    total_iter[0] = ckpt['iteration']

    # Strip a possible DDP 'module.' prefix so the weights load into either
    # a wrapped or an unwrapped model.
    stripped = {}
    for key, value in ckpt['state_dict'].items():
        stripped[key.replace('module.', '')] = value
    getattr(model, 'module', model).load_state_dict(stripped)
    optimizer.load_state_dict(ckpt['optimizer'])

    if args.amp:
        scaler.load_state_dict(ckpt['scaler'])

    if ema_model is not None:
        ema_model.load_state_dict(ckpt['ema_state_dict'])
|
|
|
|
|
|
def validate(model, epoch, total_iter, criterion, valset, batch_size,
             collate_fn, distributed_run, batch_to_gpu, ema=False):
    """Handles all the validation scoring and printing.

    Runs `model` over `valset` under no_grad, accumulates the per-batch
    loss metrics returned by `criterion` (normalized by dataset size),
    logs them through the project logger under the 'val' (or 'val_ema')
    subset, restores the model's training mode, and returns the metric
    dict.

    Fixes vs. previous revision: in the non-distributed branch
    `val_num_frames` was overwritten each batch instead of accumulated,
    and the logged 'frames/s' used only the final batch's frame count.
    """
    was_training = model.training
    model.eval()

    tik = time.perf_counter()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, num_workers=4, shuffle=False,
                                sampler=val_sampler,
                                batch_size=batch_size, pin_memory=False,
                                collate_fn=collate_fn)
        val_meta = defaultdict(float)
        val_num_frames = 0
        for i, batch in enumerate(val_loader):
            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x)
            loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')

            if distributed_run:
                for k, v in meta.items():
                    val_meta[k] += reduce_tensor(v, 1)
                val_num_frames += reduce_tensor(num_frames.data, 1).item()
            else:
                for k, v in meta.items():
                    val_meta[k] += v
                # Fix: accumulate over batches (was `=`, keeping only the
                # last batch's count).
                val_num_frames += num_frames.item()

        # Metrics were summed per-sample (meta_agg='sum'); average them
        # over the whole validation set.
        val_meta = {k: v / len(valset) for k, v in val_meta.items()}

    val_meta['took'] = time.perf_counter() - tik

    logger.log((epoch,) if epoch is not None else (),
               tb_total_steps=total_iter,
               subset='val_ema' if ema else 'val',
               data=OrderedDict([
                   ('loss', val_meta['loss'].item()),
                   ('mel_loss', val_meta['mel_loss'].item()),
                   # Fix: throughput over the full pass, not the last batch.
                   ('frames/s', val_num_frames / val_meta['took']),
                   ('took', val_meta['took'])]),
               )

    if was_training:
        model.train()
    return val_meta
|
|
|
|
|
|
def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    """Apply a Noam-style warmup/decay schedule to every param group.

    warmup_iters == 0 disables scheduling (scale 1.0).  During warmup the
    scale grows linearly as total_iter / warmup_iters**1.5; afterwards it
    decays as 1/sqrt(total_iter).
    """
    if warmup_iters == 0:
        scale = 1.0
    else:
        in_warmup = total_iter <= warmup_iters
        if in_warmup:
            scale = total_iter / (warmup_iters ** 1.5)
        else:
            scale = 1. / (total_iter ** 0.5)

    new_lr = learning_rate * scale
    for group in opt.param_groups:
        group['lr'] = new_lr
|
|
|
|
|
|
def apply_ema_decay(model, ema_model, decay):
    """In-place EMA update: ema = decay * ema + (1 - decay) * model.

    No-op when `decay` is falsy (0 disables EMA).  When `model` is
    DDP-wrapped but `ema_model` is not, its 'module.' key prefix is added
    while looking up the matching model weights.
    """
    if not decay:
        return
    model_state = model.state_dict()
    needs_prefix = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for name, ema_value in ema_model.state_dict().items():
        key = name
        if needs_prefix and not key.startswith('module.'):
            key = 'module.' + key
        ema_value.copy_(decay * ema_value + (1 - decay) * model_state[key])
|
|
|
|
|
|
def init_multi_tensor_ema(model, ema_model):
    """Collect the tensor lists and overflow buffer for fused EMA updates.

    Returns (model_weights, ema_model_weights, overflow_buf) in the order
    expected by apply_multi_tensor_ema.  Requires CUDA (the overflow
    buffer is a CUDA int tensor).
    """
    ema_overflow_buf = torch.cuda.IntTensor([0])
    model_weights = list(model.state_dict().values())
    ema_model_weights = list(ema_model.state_dict().values())
    return model_weights, ema_model_weights, ema_overflow_buf
|
|
|
|
|
|
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    # Fused in-place EMA step over all weight tensors at once via apex's
    # amp_C kernel, computing ema = decay * ema + (1 - decay) * model.
    # NOTE(review): per the apex multi_tensor_axpby convention, 65536 is
    # presumably the kernel chunk size and -1 selects in-place output into
    # the first tensor list (ema_weights) — confirm against apex docs.
    amp_C.multi_tensor_axpby(
        65536, overflow_buf, [ema_weights, model_weights, ema_weights],
        decay, 1-decay, -1)
|
|
|
|
|
|
def main():
    """FastPitch training entry point: setup, train/validate loop, summary."""
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    # First parse pass: generic options only; model-specific options are
    # registered and parsed again below.
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)

    distributed_run = args.world_size > 1

    # Per-rank seeding so ranks do not produce identical random streams.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    # Logging is enabled only on rank 0.
    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # Second parse pass with the FastPitch-specific options added; any
    # leftover unknown options are a user error.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError

    # GradScaler is inert when AMP is disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # EMA shadow copy must be made before DDP wrapping.
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # Single-element lists: load_checkpoint fills them in as out-params.
    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args, model, ema_model, optimizer, scaler,
                        start_epoch, start_iter, model_config, ch_fpath)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))

    # With a DistributedSampler, shuffling is handled by the sampler.
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    train_loader = DataLoader(trainset, num_workers=4, shuffle=shuffle,
                              sampler=train_sampler, batch_size=args.batch_size,
                              pin_memory=True, persistent_workers=True,
                              drop_last=True, collate_fn=collate_fn)

    if args.ema_decay:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # Per-epoch running statistics; one entry appended per epoch.
    epoch_loss = []
    epoch_mel_loss = []
    epoch_num_frames = []
    epoch_frames_per_sec = []
    epoch_time = []

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss += [0.0]
        epoch_mel_loss += [0.0]
        epoch_num_frames += [0]
        epoch_frames_per_sec += [0.0]

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        # State for one optimizer step (which spans grad_accumulation
        # micro-batches).
        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = None

        epoch_iter = 0
        # Number of full optimizer steps this epoch; trailing micro-batches
        # that cannot form a full step are dropped by the break below.
        num_iters = len(train_loader) // args.grad_accumulation
        for batch in train_loader:

            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                if iter_start_time is None:
                    iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                # Hard-attention (binarization) KL term, phased in from
                # kl_loss_start_epoch over kl_loss_warmup_epochs.
                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
                    loss += kl_weight * binarization_loss

                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                # Scale so gradients sum to the true batch gradient across
                # the accumulation window.
                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation
                    for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            # A full accumulation window is complete: take the optimizer step.
            if accumulated_steps % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    # Unscale before clipping so the threshold applies to
                    # true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                epoch_frames_per_sec[-1] += iter_num_frames / iter_time
                epoch_loss[-1] += iter_loss
                epoch_num_frames[-1] += iter_num_frames
                epoch_mel_loss[-1] += iter_mel_loss

                logger.log((epoch, epoch_iter, num_iters),
                           tb_total_steps=total_iter,
                           subset='train',
                           data=OrderedDict([
                               ('loss', iter_loss),
                               ('mel_loss', iter_mel_loss),
                               ('kl_loss', iter_kl_loss),
                               ('kl_weight', kl_weight),
                               ('frames/s', iter_num_frames / iter_time),
                               ('took', iter_time),
                               ('lrate', optimizer.param_groups[0]['lr'])]),
                           )

                # Reset the accumulation window.
                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

        # Finished epoch
        epoch_loss[-1] /= epoch_iter
        epoch_mel_loss[-1] /= epoch_iter
        epoch_time += [time.perf_counter() - epoch_start_time]
        iter_start_time = None

        logger.log((epoch,),
                   tb_total_steps=None,
                   subset='train_avg',
                   data=OrderedDict([
                       ('loss', epoch_loss[-1]),
                       ('mel_loss', epoch_mel_loss[-1]),
                       ('frames/s', epoch_num_frames[-1] / epoch_time[-1]),
                       ('took', epoch_time[-1])]),
                   )

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu)

        if args.ema_decay > 0:
            validate(ema_model, epoch, total_iter, criterion, valset,
                     args.batch_size, collate_fn, distributed_run, batch_to_gpu,
                     ema=True)

        maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                              total_iter, model_config)
        logger.flush()

    # Finished training
    if args.pyprof:
        profiler.stop()
        # NOTE(review): this constructs a NEW emit_nvtx object and calls
        # __exit__ on it, rather than exiting the one entered before the
        # loop — presumably harmless, but confirm against the profiler API.
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    if len(epoch_loss) > 0:
        # Was trained - average the last 20 measurements
        last_ = lambda l: np.asarray(l[-20:])
        epoch_loss = last_(epoch_loss)
        epoch_mel_loss = last_(epoch_mel_loss)
        epoch_num_frames = last_(epoch_num_frames)
        epoch_time = last_(epoch_time)
        logger.log((),
                   tb_total_steps=None,
                   subset='train_avg',
                   data=OrderedDict([
                       ('loss', epoch_loss.mean()),
                       ('mel_loss', epoch_mel_loss.mean()),
                       ('frames/s', epoch_num_frames.sum() / epoch_time.sum()),
                       ('took', epoch_time.mean())]),
                   )

    # Final validation pass (epoch=None suppresses the epoch prefix in logs).
    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu)
|
|
|
|
|
|
# Script entry point: run training only when executed directly.
if __name__ == '__main__':
    main()
|