# DeepLearningExamples/CUDA-Optimized/FastSpeech/fastspeech/train.py
# 2020-07-31 14:59:15 +08:00
#
# 168 lines
# 6.9 KiB
# Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pprint
import fire
import torch
from torch.optim.lr_scheduler import LambdaLR
from fastspeech import DEFAULT_DEVICE
from fastspeech import hparam as hp
from fastspeech.data_load import PadDataLoader
from fastspeech.dataset.ljspeech_dataset import LJSpeechDataset
from fastspeech.model.fastspeech import Fastspeech
from fastspeech.trainer.fastspeech_trainer import FastspeechTrainer
from fastspeech.utils.logging import tprint
# apex is imported eagerly so that a missing installation is noticed when this
# script starts, not somewhere deep inside the training run.
try:
    import apex  # noqa: F401  (imported for availability only in this file)
except ImportError:
    # A bare `ImportError(...)` expression would be created and thrown away,
    # hiding the problem entirely. Emit a visible warning instead; we do not
    # raise, so the script keeps its best-effort behavior when apex is absent.
    import warnings

    warnings.warn(
        "Required to install apex. Continuing without it; "
        "apex-dependent features will be unavailable."
    )

# import multiprocessing
# multiprocessing.set_start_method('spawn', True)

# Pretty-printer used by train() to log the hyper-parameters; the large width
# keeps long values from being wrapped across several lines.
pp = pprint.PrettyPrinter(indent=4, width=1000)
def train(hparam="train.yaml",
          device=DEFAULT_DEVICE,
          **kwargs):
    """ The FastSpeech model training script.

    By default, this script assumes to load parameters in the default config file, fastspeech/hparams/train.yaml.

    Besides the flags, you can also set parameters in the config file via the command-line. For examples,
    --dataset_path=DATASET_PATH
        Path to dataset directory.
    --tacotron2_path=TACOTRON2_PATH
        Path to tacotron2 checkpoint file.
    --mels_path=MELS_PATH
        Path to preprocessed mels directory.
    --aligns_path=ALIGNS_PATH
        Path to preprocessed alignments directory.
    --log_path=LOG_PATH
        Path to log directory.
    --checkpoint_path=CHECKPOINT_PATH
        Path to checkpoint directory. The latest checkpoint will be loaded.
    --batch_size=BATCH_SIZE
        Batch size to use. Defaults to 16.

    Refer to fastspeech/hparams/train.yaml to see more parameters.

    Args:
        hparam (str, optional): Path to default config file. Defaults to "train.yaml".
        device (str, optional): Device to use. Defaults to "cuda" if available, or "cpu".
    """
    # NOTE: fire renders the docstring above as the command-line help text.

    # Load the config file and overlay any command-line overrides onto `hp`.
    hp.set_hparam(hparam, kwargs)
    tprint("Hparams:\n{}".format(pp.pformat(hp)))
    tprint("Device count: {}".format(torch.cuda.device_count()))

    # ---- model ----------------------------------------------------------
    model_config = {
        "max_seq_len": hp.max_seq_len,
        "d_model": hp.d_model,
        "phoneme_side_n_layer": hp.phoneme_side_n_layer,
        "phoneme_side_head": hp.phoneme_side_head,
        "phoneme_side_conv1d_filter_size": hp.phoneme_side_conv1d_filter_size,
        "phoneme_side_output_size": hp.phoneme_side_output_size,
        "mel_side_n_layer": hp.mel_side_n_layer,
        "mel_side_head": hp.mel_side_head,
        "mel_side_conv1d_filter_size": hp.mel_side_conv1d_filter_size,
        "mel_side_output_size": hp.mel_side_output_size,
        "duration_predictor_filter_size": hp.duration_predictor_filter_size,
        "duration_predictor_kernel_size": hp.duration_predictor_kernel_size,
        "fft_conv1d_kernel": hp.fft_conv1d_kernel,
        "fft_conv1d_padding": hp.fft_conv1d_padding,
        "dropout": hp.dropout,
        "n_mels": hp.num_mels,
        "fused_layernorm": hp.fused_layernorm,
    }
    model = Fastspeech(**model_config)

    # ---- dataset --------------------------------------------------------
    dataset_config = {
        "root_path": hp.dataset_path,
        "meta_file": hp.meta_file,
        "mels_path": hp.mels_path,
        "aligns_path": hp.aligns_path,
        "sr": hp.sr,
        "n_fft": hp.n_fft,
        "win_len": hp.win_len,
        "hop_len": hp.hop_len,
        "n_mels": hp.num_mels,
        "mel_fmin": hp.mel_fmin,
        "mel_fmax": hp.mel_fmax,
    }
    dataset = LJSpeechDataset(**dataset_config)
    tprint("Dataset size: {}".format(len(dataset)))

    # ---- data loader ----------------------------------------------------
    data_loader = PadDataLoader(
        dataset,
        batch_size=hp.batch_size,
        num_workers=hp.n_workers,
        drop_last=True,
    )

    # ---- optimizer / scheduler factories ----------------------------------
    # Both are handed to the trainer as callables; the trainer decides when
    # to invoke them.
    def make_optimizer(net):
        """Return an Adam optimizer over all parameters of `net`."""
        return torch.optim.Adam(
            net.parameters(),
            lr=hp.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-9,
        )

    def make_warmup_scheduler(optimizer):
        """Return a LambdaLR applying the warmup / inverse-sqrt-decay rule."""
        model_dim = hp.d_model
        n_warmup = hp.warmup_steps

        def lr_factor(step):
            # LambdaLR multiplies the optimizer's initial lr by this factor,
            # so dividing by hp.learning_rate makes the product equal to the
            # schedule value itself.
            t = step + 1
            decay = t ** -0.5
            ramp = t * n_warmup ** -1.5
            return model_dim ** -0.5 * min(decay, ramp) / hp.learning_rate

        # One entry in the list: the optimizer above has a single param group.
        return LambdaLR(optimizer, lr_lambda=[lr_factor])

    # ---- trainer --------------------------------------------------------
    trainer_options = {
        "optimizer_fn": make_optimizer,
        "final_steps": hp.final_steps,
        "log_steps": hp.log_step,
        "ckpt_path": hp.checkpoint_path,
        "save_steps": hp.save_step,
        "log_path": hp.log_path,
        "lr_scheduler_fn": make_warmup_scheduler,
        # Flag is on whenever an alignments path has been configured.
        "pre_aligns": bool(hp.aligns_path),
        "device": device,
        "use_amp": hp.use_amp,
        "nvprof_iter_start": hp.nvprof_iter_start,
        "nvprof_iter_end": hp.nvprof_iter_end,
        "pyprof_enabled": hp.pyprof_enabled,
    }
    trainer = FastspeechTrainer(data_loader, 'fastspeech', model, **trainer_options)
    trainer.train()
if __name__ == "__main__":
    # Script entry point: configure cuDNN, then let fire expose `train` (its
    # parameters and extra hparam overrides) as a command-line interface.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.enabled = True      # keep cuDNN in use
    cudnn_backend.benchmark = False   # leave cuDNN benchmark mode off
    fire.Fire(train)