# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import os
import time
import argparse
import numpy as np
from contextlib import contextmanager

import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

from apex.fp16_utils import FP16_Optimizer
from apex.parallel import DistributedDataParallel as DDP

import models
import loss_functions
import data_functions

from dllogger.logger import LOGGER
import dllogger.logger as dllg
from dllogger import tags
from dllogger.autologging import log_hardware, log_args
from scipy.io.wavfile import write as write_wav

def parse_args(parser):
    """
    Parse commandline arguments.
    """

    parser.add_argument('-o', '--output_directory', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str,
                        default='./', help='Path to dataset')
    parser.add_argument('-m', '--model-name', type=str, default='', required=True,
                        help='Model to train')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')
    parser.add_argument('--phrase-path', type=str, default=None,
                        help='Path to phrase sequence file used for sample generation')
    parser.add_argument('--waveglow-checkpoint', type=str, default=None,
                        help='Path to pre-trained WaveGlow checkpoint for sample generation')
    parser.add_argument('--tacotron2-checkpoint', type=str, default=None,
                        help='Path to pre-trained Tacotron2 checkpoint for sample generation')
    parser.add_argument('--anneal-steps', nargs='*',
                        help='Epochs after which to decrease the learning rate')
    parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
                        help='Factor for annealing learning rate')

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', type=int, required=True,  # default=1500,
                          help='Number of total epochs to run')
    training.add_argument('--epochs-per-checkpoint', type=int, default=10,
                          help='Number of epochs per checkpoint')
    training.add_argument('--seed', type=int, default=1234,
                          help='Seed for PyTorch random number generators')
    training.add_argument('--dynamic-loss-scaling', type=bool, default=True,
                          help='Enable dynamic loss scaling')
    training.add_argument('--fp16-run', action='store_true',
                          help='Enable fp16')
    training.add_argument('--cudnn-enabled', type=bool, default=True,
                          help='Enable cudnn')
    training.add_argument('--cudnn-benchmark', type=bool, default=False,
                          help='Run cudnn benchmark')

    optimization = parser.add_argument_group('optimization setup')
    optimization.add_argument(
        '--use-saved-learning-rate', default=False, type=bool)
    optimization.add_argument('-lr', '--learning-rate', type=float, required=True,
                              help='Learning rate')
    optimization.add_argument('--weight-decay', default=1e-6, type=float,
                              help='Weight decay')
    optimization.add_argument('--grad-clip-thresh', default=1.0, type=float,
                              help='Clip threshold for gradients')
    optimization.add_argument('-bs', '--batch-size', type=int, required=True,
                              help='Batch size per GPU')
    optimization.add_argument('--grad-clip', default=5.0, type=float,
                              help='Enables gradient clipping and sets maximum gradient norm value')

    # dataset parameters
    dataset = parser.add_argument_group('dataset parameters')
    dataset.add_argument('--load-mel-from-disk', default=False, type=bool,
                         help='Loads mel spectrograms from disk instead of computing them on the fly')
    dataset.add_argument('--training-files',
                         default='filelists/ljs_audio_text_train_filelist.txt',
                         type=str, help='Path to training filelist')
    dataset.add_argument('--validation-files',
                         default='filelists/ljs_audio_text_val_filelist.txt',
                         type=str, help='Path to validation filelist')
    dataset.add_argument('--text-cleaners', nargs='*',
                         default=['english_cleaners'], type=str,
                         help='Type of text cleaners for input text')

    # audio parameters
    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audio wave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    distributed = parser.add_argument_group('distributed setup')
    # distributed.add_argument('--distributed-run', default=True, type=bool,
    #                          help='enable distributed run')
    distributed.add_argument('--rank', default=0, type=int,
                             help='Rank of the process, do not set! Done by multiproc module')
    distributed.add_argument('--world-size', default=1, type=int,
                             help='Number of processes, do not set! Done by multiproc module')
    distributed.add_argument('--dist-url', type=str, default='tcp://localhost:23456',
                             help='URL used to set up distributed training')
    distributed.add_argument('--group-name', type=str, default='group_name',
                             required=False, help='Distributed group name')
    distributed.add_argument('--dist-backend', default='nccl', type=str, choices={'nccl'},
                             help='Distributed run backend')

    return parser

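# Illustrative usage (added comment): main() below builds the final parser from
# this base and then extends it with model-specific arguments, e.g.
#     parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
#     parser = parse_args(parser)
#     args, _ = parser.parse_known_args()
#     parser = models.parse_model_args(args.model_name, parser)
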
def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= num_gpus
    return rt

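# Example (added comment): reduce_tensor averages a value across ranks, e.g.
# with world_size=4 and per-rank losses 1.0, 2.0, 3.0, 4.0 every rank ends up
# with 2.5. Passing num_gpus=1 (as done for num_items in the training loop)
# sums across ranks instead of averaging.
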
# shamelessly copied from apex/fp16_utils/fp16_optimizer.py
# commit 34582381e289287e6d09490b69c0840718729f2e
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_fp16(val):
    """Convert fp32 `val` to fp16"""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, FLOAT_TYPES):
            val = val.half()
        return val
    return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
    """Convert fp16 `val` to fp32"""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, HALF_TYPES):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)

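# Example (added comment, illustrative batch layout): these helpers recurse
# through nested tuples/lists, so for a batch like
#     (text_padded, input_lengths, (mel_padded, gate_padded))
# fp32_to_fp16 halves every fp32 tensor it finds and fp16_to_fp32 restores
# them, while non-float tensors (e.g. integer length tensors) pass through
# unchanged.
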
def init_distributed(args, world_size, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=world_size, rank=rank, group_name=group_name)

    print("Done initializing distributed")

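# Note (added comment): each training process is expected to call this with its
# own --rank and --world-size, which the repository's multiproc launcher fills
# in (see the argument help strings above); the GPU is selected as rank modulo
# the number of visible devices.
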
def save_checkpoint(model, epoch, config, filepath):
    print("Saving model state at epoch {} to {}".format(
        epoch, filepath))
    torch.save({'epoch': epoch,
                'config': config,
                'state_dict': model.state_dict()}, filepath)

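# Sketch (added comment) of how such a checkpoint can be restored, assuming the
# same model definition is available:
#     checkpoint = torch.load(filepath, map_location='cpu')
#     model = models.get_model(model_name, checkpoint['config'],
#                              to_fp16=False, to_cuda=False)
#     model.load_state_dict(checkpoint['state_dict'])
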
def save_sample(model_name, model, waveglow_path, tacotron2_path, phrase_path, filepath, sampling_rate, fp16):
    if phrase_path is None:
        return
    phrase = torch.load(phrase_path, map_location='cpu')
    if model_name == 'Tacotron2':
        if waveglow_path is None:
            raise Exception(
                "WaveGlow checkpoint path is missing, could not generate sample")
        with torch.no_grad():
            checkpoint = torch.load(waveglow_path, map_location='cpu')
            waveglow = models.get_model(
                'WaveGlow', checkpoint['config'], to_fp16=False, to_cuda=False)
            waveglow.eval()
            model.eval()
            mel = model.infer(phrase.cuda())[0].cpu()
            model.train()
            if fp16:
                mel = mel.float()
            audio = waveglow.infer(mel, sigma=0.6)
    elif model_name == 'WaveGlow':
        if tacotron2_path is None:
            raise Exception(
                "Tacotron2 checkpoint path is missing, could not generate sample")
        with torch.no_grad():
            checkpoint = torch.load(tacotron2_path, map_location='cpu')
            tacotron2 = models.get_model(
                'Tacotron2', checkpoint['config'], to_fp16=False, to_cuda=False)
            tacotron2.eval()
            mel = tacotron2.infer(phrase)[0].cuda()
            model.eval()
            if fp16:
                mel = mel.half()
            audio = model.infer(mel, sigma=0.6).cpu()
            model.train()
            if fp16:
                audio = audio.float()
    else:
        raise NotImplementedError(
            "unknown model requested: {}".format(model_name))
    audio = audio[0].numpy()
    audio = audio.astype('int16')
    write_wav(filepath, sampling_rate, audio)

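# Note (added comment): the training loop below calls this once per checkpoint
# with the checkpoint of the complementary model, i.e. when training Tacotron2
# a --waveglow-checkpoint vocodes the predicted mel spectrogram into audio, and
# when training WaveGlow a --tacotron2-checkpoint provides the mel spectrogram
# to vocode.
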
# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
# Following snippet is licensed under MIT license


@contextmanager
def evaluating(model):
    '''Temporarily switch to evaluation mode.'''
    istrain = model.training
    try:
        model.eval()
        yield model
    finally:
        if istrain:
            model.train()

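# Usage example (added comment), as in validate() below:
#     with evaluating(model), torch.no_grad():
#         y_pred = model(x)
# The model is returned to training mode afterwards only if it was in training
# mode before entering the block.
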
def validate(model, criterion, valset, iteration, batch_size, world_size,
             collate_fn, distributed_run, rank, batch_to_gpu, fp16_run):
    """Handles all the validation scoring and printing"""
    with evaluating(model), torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, num_workers=1, shuffle=False,
                                sampler=val_sampler,
                                batch_size=batch_size, pin_memory=False,
                                collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y, len_x = batch_to_gpu(batch)
            if fp16_run:
                y_pred = model(fp32_to_fp16(x))
                loss = criterion(fp16_to_fp32(y_pred), y)
            else:
                y_pred = model(x)
                loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, world_size).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)

    LOGGER.log(key="val_iter_loss", value=reduced_val_loss)

def adjust_learning_rate(epoch, optimizer, learning_rate,
                         anneal_steps, anneal_factor):

    p = 0
    if anneal_steps is not None:
        for i, a_step in enumerate(anneal_steps):
            if epoch >= int(a_step):
                p = p+1

    if anneal_factor == 0.3:
        lr = learning_rate*((0.1 ** (p//2))*(1.0 if p % 2 == 0 else 0.3))
    else:
        lr = learning_rate*(anneal_factor ** p)

    if optimizer.param_groups[0]['lr'] != lr:
        LOGGER.log_event("learning_rate changed",
                         value=str(optimizer.param_groups[0]['lr']) + " -> " + str(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

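# Worked example (added comment, illustrative step values): with
# --anneal-steps 500 1000 1500 and the default --anneal-factor 0.1, the
# learning rate is multiplied by 0.1 after epoch 500, 0.01 after epoch 1000 and
# 0.001 after epoch 1500. With --anneal-factor 0.3 the multipliers instead
# follow 0.3, 0.1, 0.03, ...
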
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time",
                           metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)

    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    distributed_run = args.world_size > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(model_name, model_config,
                             to_fp16=args.fp16_run, to_cuda=True)
    if distributed_run:
        model = DDP(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.fp16_run:
        optimizer = FP16_Optimizer(
            optimizer, dynamic_loss_scale=args.dynamic_loss_scaling)

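    # Note (added comment): with --fp16-run the model weights are converted to
    # half precision by models.get_model, while apex's FP16_Optimizer keeps
    # fp32 master copies of the parameters and applies (dynamic) loss scaling;
    # the training loop below therefore uses optimizer.backward(loss) and
    # clip_master_grads() instead of the plain fp32 path.
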
    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(
        model_name, n_frames_per_step)
    trainset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.training_files, args)
    train_sampler = DistributedSampler(trainset) if distributed_run else None
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=args.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)

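    # Note (added comment): --batch-size is the per-GPU batch size; in a
    # distributed run the DistributedSampler gives each rank a different shard
    # of the dataset, so the effective global batch size is
    # batch_size * world_size.
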
    valset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    model.train()

    LOGGER.log(key=tags.TRAIN_LOOP)

    for epoch in range(args.epochs):
        LOGGER.epoch_start()
        epoch_start_time = time.time()
        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        # used to calculate avg loss over epoch
        train_epoch_avg_loss = 0.0
        num_iters = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        for i, batch in enumerate(train_loader):
            LOGGER.iteration_start()
            iter_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_ITER_START, value=i)
            print("Batch: {}/{} epoch {}".format(i, len(train_loader), epoch))

            start = time.perf_counter()
            adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            if args.fp16_run:
                y_pred = model(fp32_to_fp16(x))
                loss = criterion(fp16_to_fp32(y_pred), y)
            else:
                y_pred = model(x)
                loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

            train_epoch_avg_loss += reduced_loss
            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.fp16_run:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_master_grads(args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            overflow = optimizer.overflow if args.fp16_run else False
            iteration += 1

            LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

            iter_stop_time = time.time()
            iter_time = iter_stop_time - iter_start_time
            LOGGER.log(key="train_iter_items/sec",
                       value=(reduced_num_items/iter_time))
            LOGGER.log(key="iter_time", value=iter_time)
            LOGGER.iteration_stop()

        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        LOGGER.log(key="train_epoch_items/sec",
                   value=(reduced_num_items_epoch/epoch_time))
        LOGGER.log(key="train_epoch_avg_loss", value=(
            train_epoch_avg_loss/num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="epoch_time", value=epoch_time)

        LOGGER.log(key=tags.EVAL_START, value=epoch)

        validate(model, criterion, valset, iteration,
                 args.batch_size, args.world_size, collate_fn,
                 distributed_run, args.rank, batch_to_gpu, args.fp16_run)

        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        if not overflow and (epoch % args.epochs_per_checkpoint == 0) and args.rank == 0:
            checkpoint_path = os.path.join(
                args.output_directory, "checkpoint_{}_{}".format(model_name, epoch))
            save_checkpoint(model, epoch, model_config, checkpoint_path)
            save_sample(model_name, model, args.waveglow_checkpoint,
                        args.tacotron2_checkpoint, args.phrase_path,
                        os.path.join(args.output_directory,
                                     "sample_{}_{}.wav".format(model_name, iteration)),
                        args.sampling_rate, args.fp16_run)

        LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()


if __name__ == '__main__':
    main()