Add Paarth's HiFi-GAN and Tacotron fine-tuning code (#3000)

Signed-off-by: Jocelyn Huang <jocelynh@nvidia.com>
Jocelyn 2021-10-13 13:38:50 -07:00 committed by GitHub
parent 38e74de2b9
commit 9bbbe2e403
8 changed files with 363 additions and 23 deletions


@@ -18,6 +18,8 @@ model:
   pitch_fmax: 640
   pitch_avg: 211.27540199742586 # for a female speaker (8051 HiFiGAN)
   pitch_std: 52.1851002822779 # for a female speaker (8051 HiFiGAN)
+  dur_loss_scale: 0.1
+  pitch_loss_scale: 0.1
 
   train_ds:
     dataset:


@@ -0,0 +1,67 @@
name: "HifiGan"

train_dataset: ???
validation_datasets: ???

defaults:
  - model/generator: v4
  - model/train_ds: train_ds
  - model/validation_ds: val_ds

model:
  preprocessor:
    _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
    dither: 0.0
    frame_splicing: 1
    nfilt: 80
    highfreq: null
    log: true
    log_zero_guard_type: clamp
    log_zero_guard_value: 1e-05
    lowfreq: 0
    mag_power: 1.0
    n_fft: 2048
    n_window_size: 2048
    n_window_stride: 512
    normalize: null
    pad_to: 0
    pad_value: -11.52
    preemph: null
    sample_rate: 44100
    window: hann
    use_grads: false
    exact_pad: true

  optim:
    _target_: torch.optim.AdamW
    lr: 0.0002
    betas: [0.8, 0.99]

  sched:
    name: CosineAnnealing
    min_lr: 1e-5
    warmup_ratio: 0.02

  max_steps: 25000000
  l1_loss_factor: 45
  denoise_strength: 0.0025

trainer:
  gpus: -1 # number of gpus
  max_steps: ${model.max_steps}
  num_nodes: 1
  accelerator: ddp
  accumulate_grad_batches: 1
  checkpoint_callback: False # Provided by exp_manager
  logger: False # Provided by exp_manager
  flush_logs_every_n_steps: 200
  log_every_n_steps: 100
  check_val_every_n_epoch: 10

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: "val_loss"
    mode: "min"
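
Note that the scheduler warmup in this config is a ratio rather than a step count: HifiGanModel (see the configure_optimizers diff further below) computes warmup_steps = model.sched.warmup_ratio * model.max_steps. A quick sanity check of the values above, as a minimal sketch:

# Minimal sketch: reproduces how HifiGanModel turns warmup_ratio into an
# absolute warmup step count for the CosineAnnealing schedule.
max_steps = 25_000_000
warmup_ratio = 0.02
warmup_steps = int(warmup_ratio * max_steps)
print(warmup_steps)  # 500000 steps of warmup before cosine decay begins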


@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 1
upsample_rates: [8,8,4,2]
upsample_kernel_sizes: [16,16,4,4]
upsample_initial_channel: 512
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
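
One invariant worth checking in this generator config: the product of upsample_rates must equal the preprocessor hop size (n_window_stride: 512 above), since the generator expands each mel frame into exactly one hop's worth of audio samples. A minimal sketch:

import math

# Minimal sketch: the generator's total upsampling factor should match the
# 512-sample hop of the 44.1 kHz mel preprocessor.
upsample_rates = [8, 8, 4, 2]
assert math.prod(upsample_rates) == 512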


@@ -0,0 +1,175 @@
name: Tacotron2
sample_rate: 44100
# <PAD>, <BOS>, <EOS> will be added by the tacotron2.py script
labels:
  - ' '
  - '!'
  - '"'
  - ''''
  - (
  - )
  - ','
  - '-'
  - .
  - ':'
  - ;
  - '?'
  - a
  - b
  - c
  - d
  - e
  - f
  - g
  - h
  - i
  - j
  - k
  - l
  - m
  - 'n'
  - o
  - p
  - q
  - r
  - s
  - t
  - u
  - v
  - w
  - x
  - 'y'
  - z
n_fft: 2048
n_mels: 80
fmax: null
n_stride: 512
pad_value: -11.52

train_dataset: ???
validation_datasets: ???

model:
  labels: ${labels}

  train_ds:
    dataset:
      _target_: "nemo.collections.asr.data.audio_to_text.AudioToCharDataset"
      manifest_filepath: ${train_dataset}
      max_duration: null
      min_duration: 0.1
      trim: false
      int_values: false
      normalize: true
      sample_rate: ${sample_rate}
      # bos_id: 66
      # eos_id: 67
      # pad_id: 68 These parameters are added automatically in Tacotron2
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 48
      num_workers: 4

  validation_ds:
    dataset:
      _target_: "nemo.collections.asr.data.audio_to_text.AudioToCharDataset"
      manifest_filepath: ${validation_datasets}
      max_duration: null
      min_duration: 0.1
      int_values: false
      normalize: true
      sample_rate: ${sample_rate}
      trim: false
      # bos_id: 66
      # eos_id: 67
      # pad_id: 68 These parameters are added automatically in Tacotron2
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 48
      num_workers: 8

  preprocessor:
    _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
    dither: 0.0
    nfilt: ${n_mels}
    frame_splicing: 1
    highfreq: ${fmax}
    log: true
    log_zero_guard_type: clamp
    log_zero_guard_value: 1e-05
    lowfreq: 0
    mag_power: 1.0
    n_fft: ${n_fft}
    n_window_size: 2048
    n_window_stride: ${n_stride}
    normalize: null
    pad_to: 16
    pad_value: ${pad_value}
    preemph: null
    sample_rate: ${sample_rate}
    window: hann

  encoder:
    _target_: nemo.collections.tts.modules.tacotron2.Encoder
    encoder_kernel_size: 5
    encoder_n_convolutions: 3
    encoder_embedding_dim: 512

  decoder:
    _target_: nemo.collections.tts.modules.tacotron2.Decoder
    decoder_rnn_dim: 1024
    encoder_embedding_dim: ${model.encoder.encoder_embedding_dim}
    gate_threshold: 0.5
    max_decoder_steps: 1000
    n_frames_per_step: 1 # currently only 1 is supported
    n_mel_channels: ${n_mels}
    p_attention_dropout: 0.1
    p_decoder_dropout: 0.1
    prenet_dim: 256
    prenet_p_dropout: 0.5
    # Attention parameters
    attention_dim: 128
    attention_rnn_dim: 1024
    # AttentionLocation Layer parameters
    attention_location_kernel_size: 31
    attention_location_n_filters: 32
    early_stopping: true

  postnet:
    _target_: nemo.collections.tts.modules.tacotron2.Postnet
    n_mel_channels: ${n_mels}
    p_dropout: 0.5
    postnet_embedding_dim: 512
    postnet_kernel_size: 5
    postnet_n_convolutions: 5

  optim:
    name: adam
    lr: 1e-3
    weight_decay: 1e-6

    # scheduler setup
    sched:
      name: CosineAnnealing
      min_lr: 1e-5

trainer:
  gpus: 1 # number of gpus
  max_epochs: ???
  num_nodes: 1
  accelerator: ddp
  accumulate_grad_batches: 1
  checkpoint_callback: False # Provided by exp_manager
  logger: False # Provided by exp_manager
  gradient_clip_val: 1.0
  flush_logs_every_n_steps: 1000
  log_every_n_steps: 200
  check_val_every_n_epoch: 25

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: True
  create_checkpoint_callback: True
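
train_dataset and validation_datasets point at NeMo-style JSON-lines manifests consumed by AudioToCharDataset: one JSON object per line carrying an audio path, a duration, and a transcript. A minimal sketch of writing one entry (the paths and values are placeholders, not from this commit):

import json

# Hedged sketch of one manifest line for AudioToCharDataset. The
# audio_filepath/duration/text keys follow NeMo's manifest convention;
# the concrete values here are placeholders.
entry = {
    "audio_filepath": "/data/audio/sample_0001.wav",
    "duration": 2.5,        # seconds; rows outside min/max_duration are dropped
    "text": "hello world",  # lowercase, matching the labels list above
}
with open("train_manifest.json", "a") as f:
    f.write(json.dumps(entry) + "\n")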


@@ -0,0 +1,32 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from nemo.collections.tts.models import HifiGanModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/hifigan", config_name="hifigan44100")
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = HifiGanModel(cfg=cfg.model, trainer=trainer)
    model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
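
The maybe_init_from_pretrained_checkpoint(cfg=cfg) call is what makes this a fine-tuning script: it checks the top-level config for an init_from_* key and loads matching weights before fit(). A hedged sketch of the keys NeMo's ModelPT checks for (the pretrained-model name below is an assumption, not from this commit):

from omegaconf import OmegaConf

# Hedged sketch: maybe_init_from_pretrained_checkpoint() looks for one of
# these top-level keys, so a fine-tuning run sets exactly one of them.
#   init_from_nemo_model       - path to a local .nemo checkpoint
#   init_from_pretrained_model - name of a pretrained model (assumed: "tts_hifigan")
#   init_from_ptl_ckpt         - path to a PyTorch Lightning .ckpt file
overrides = OmegaConf.create({"init_from_pretrained_model": "tts_hifigan"})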


@@ -0,0 +1,45 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import Tacotron2Model
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


# hydra_runner is a thin NeMo wrapper around Hydra.
# It looks for a config named tacotron2_44100.yaml inside the conf folder.
# Hydra parses the yaml and returns it as an OmegaConf DictConfig.
@hydra_runner(config_path="conf", config_name="tacotron2_44100")
def main(cfg):
    # Define the Lightning trainer
    trainer = pl.Trainer(**cfg.trainer)
    # exp_manager is a NeMo construct that helps with logging and checkpointing
    exp_manager(trainer, cfg.get("exp_manager", None))
    # Define the Tacotron 2 model; this constructs the model and also defines
    # the training and validation dataloaders
    model = Tacotron2Model(cfg=cfg.model, trainer=trainer)
    model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
    # Let's add a few more callbacks
    lr_logger = pl.callbacks.LearningRateMonitor()
    epoch_time_logger = LogEpochTimeCallback()
    trainer.callbacks.extend([lr_logger, epoch_time_logger])
    # Call the Lightning trainer's fit() to train the model
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
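
After both models are fine-tuned, inference chains them through NeMo's SpectrogramGenerator and Vocoder interfaces. A hedged sketch, assuming checkpoint paths that are placeholders rather than outputs of this commit:

from nemo.collections.tts.models import HifiGanModel, Tacotron2Model

# Hedged sketch: text -> mel spectrogram -> 44.1 kHz audio.
# The .nemo paths below are assumed placeholders.
spec_gen = Tacotron2Model.restore_from("tacotron2_44100.nemo")
vocoder = HifiGanModel.restore_from("hifigan_44100.nemo")

tokens = spec_gen.parse("hello world")                      # text to token ids
spectrogram = spec_gen.generate_spectrogram(tokens=tokens)  # mel frames
audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)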


@@ -81,8 +81,14 @@ class FastPitchModel(SpectrogramGenerator, Exportable):
         self.log_train_images = False
         self.mel_loss = MelLoss()
         loss_scale = 0.1 if self.learn_alignment else 1.0
-        self.pitch_loss = PitchLoss(loss_scale=loss_scale)
-        self.duration_loss = DurationLoss(loss_scale=loss_scale)
+        dur_loss_scale = loss_scale
+        pitch_loss_scale = loss_scale
+        if "dur_loss_scale" in cfg:
+            dur_loss_scale = cfg.dur_loss_scale
+        if "pitch_loss_scale" in cfg:
+            pitch_loss_scale = cfg.pitch_loss_scale
+        self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
+        self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)
         input_fft_kwargs = {}
         if self.learn_alignment:
             self.aligner = instantiate(self._cfg.alignment_module)
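
The effect of the change above: both loss scales still default to the alignment-dependent value (0.1 when learn_alignment is set, 1.0 otherwise), and the new dur_loss_scale/pitch_loss_scale config keys override them only when present. A minimal illustration with an OmegaConf config (the 0.05 value is illustrative only):

from omegaconf import OmegaConf

# Illustrative sketch of the fallback logic: an absent key keeps the default.
loss_scale = 0.1  # as when learn_alignment is True
cfg = OmegaConf.create({"pitch_loss_scale": 0.05})
dur_loss_scale = cfg.dur_loss_scale if "dur_loss_scale" in cfg else loss_scale
pitch_loss_scale = cfg.pitch_loss_scale if "pitch_loss_scale" in cfg else loss_scale
print(dur_loss_scale, pitch_loss_scale)  # 0.1 0.05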


@@ -76,26 +76,29 @@ class HifiGanModel(Vocoder, Exportable):
             self._cfg.optim, params=itertools.chain(self.msd.parameters(), self.mpd.parameters()),
         )
 
-        self.scheduler_g = CosineAnnealing(
-            optimizer=self.optim_g,
-            max_steps=self._cfg.max_steps,
-            min_lr=self._cfg.sched.min_lr,
-            warmup_steps=self._cfg.sched.warmup_ratio * self._cfg.max_steps,
-        )  # Use warmup to delay start
-        sch1_dict = {
-            'scheduler': self.scheduler_g,
-            'interval': 'step',
-        }
+        if hasattr(self._cfg, 'sched'):
+            self.scheduler_g = CosineAnnealing(
+                optimizer=self.optim_g,
+                max_steps=self._cfg.max_steps,
+                min_lr=self._cfg.sched.min_lr,
+                warmup_steps=self._cfg.sched.warmup_ratio * self._cfg.max_steps,
+            )  # Use warmup to delay start
+            sch1_dict = {
+                'scheduler': self.scheduler_g,
+                'interval': 'step',
+            }
 
-        self.scheduler_d = CosineAnnealing(
-            optimizer=self.optim_d, max_steps=self._cfg.max_steps, min_lr=self._cfg.sched.min_lr,
-        )
-        sch2_dict = {
-            'scheduler': self.scheduler_d,
-            'interval': 'step',
-        }
+            self.scheduler_d = CosineAnnealing(
+                optimizer=self.optim_d, max_steps=self._cfg.max_steps, min_lr=self._cfg.sched.min_lr,
+            )
+            sch2_dict = {
+                'scheduler': self.scheduler_d,
+                'interval': 'step',
+            }
 
-        return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]
+            return [self.optim_g, self.optim_d], [sch1_dict, sch2_dict]
+        else:
+            return [self.optim_g, self.optim_d]
 
     @property
     def input_types(self):
@@ -169,9 +172,11 @@ class HifiGanModel(Vocoder, Exportable):
             self.optim_g.step()
 
         # run schedulers
-        sch1, sch2 = self.lr_schedulers()
-        sch1.step()
-        sch2.step()
+        schedulers = self.lr_schedulers()
+        if schedulers is not None:
+            sch1, sch2 = schedulers
+            sch1.step()
+            sch2.step()
 
         metrics = {
             "g_loss_fm_mpd": loss_fm_mpd,