From 7d8b7d679a2a4c9d3a6b2d8def8888e53c804ded Mon Sep 17 00:00:00 2001 From: Felix Kreuk Date: Tue, 16 Feb 2021 23:56:29 +0200 Subject: [PATCH] HifiGAN MelSpectrogram Vocoder Model (#1706) * HifiGAN: initial commit Signed-off-by: Felix Kreuk * switched to manual optimization, added grads to FilterBanks object, added v1,v2,v3 configurations for the generator added multiple audio/spec examples in wandb Signed-off-by: Felix Kreuk * added trg_mel_fn for different mel params in L1 loss Signed-off-by: Felix Kreuk * fixed ci checks Signed-off-by: Felix Kreuk * moved losses to separate classes, fixed exp_manager in yaml, style fixes Signed-off-by: Felix Kreuk * fixed `min_duration` in yamls of MelGAN and HifiGAN Signed-off-by: Felix Kreuk * fixed f_max in loss calculation, added docs Signed-off-by: Felix Kreuk Co-authored-by: Oleksii Kuchaiev Co-authored-by: Jason --- .gitignore | 1 + docs/source/tts/models.rst | 6 + docs/source/tts/tts_all.bib | 9 +- examples/tts/conf/hifigan/hifigan.yaml | 84 ++++ .../tts/conf/hifigan/model/generator/v1.yaml | 8 + .../tts/conf/hifigan/model/generator/v2.yaml | 8 + .../tts/conf/hifigan/model/generator/v3.yaml | 8 + examples/tts/conf/melgan.yaml | 2 +- examples/tts/hifigan.py | 34 ++ nemo/collections/asr/parts/features.py | 12 +- nemo/collections/tts/losses/hifigan_losses.py | 132 ++++++ nemo/collections/tts/models/__init__.py | 2 + nemo/collections/tts/models/hifigan.py | 219 +++++++++ .../tts/modules/hifigan_modules.py | 434 ++++++++++++++++++ 14 files changed, 955 insertions(+), 4 deletions(-) create mode 100644 examples/tts/conf/hifigan/hifigan.yaml create mode 100644 examples/tts/conf/hifigan/model/generator/v1.yaml create mode 100644 examples/tts/conf/hifigan/model/generator/v2.yaml create mode 100644 examples/tts/conf/hifigan/model/generator/v3.yaml create mode 100644 examples/tts/hifigan.py create mode 100644 nemo/collections/tts/losses/hifigan_losses.py create mode 100644 nemo/collections/tts/models/hifigan.py create mode 100644 nemo/collections/tts/modules/hifigan_modules.py diff --git a/.gitignore b/.gitignore index 135b76c6b..ff04a2626 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ bert.pt.json work runs fastspeech_output +.hydra # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/source/tts/models.rst b/docs/source/tts/models.rst index b48257a64..4bf229b61 100644 --- a/docs/source/tts/models.rst +++ b/docs/source/tts/models.rst @@ -51,6 +51,12 @@ Full-band MelGAN The :class:`MelGanModel` class implements Full-band MelGAN as described in :cite:`yang2020multiband`. MelGAN is a GAN-based vocoder that converts mel-spectrograms to audio. +HifiGAN +------- + +The :class:`HifiGanModel` class implements HifiGAN as described in +:cite:`kong2020hifi`. HifiGAN is a GAN-based vocoder that converts mel-spectrograms to audio. 
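For reference, a minimal inference sketch (annotation, not part of the patch): the checkpoint path is hypothetical, while `restore_from` is the standard NeMo ModelPT loading API and `convert_spectrogram_to_audio` is the method this PR adds to HifiGanModel::

    import torch

    from nemo.collections.tts.models import HifiGanModel

    # Hypothetical .nemo checkpoint path; restore_from is the generic NeMo model-loading API.
    vocoder = HifiGanModel.restore_from("hifigan.nemo").eval()

    spec = torch.randn(1, 80, 100)  # (batch, n_mels, frames) mel-spectrogram
    with torch.no_grad():
        # Keyword argument is required by the typecheck'd signature; returns (batch, samples) audio at 22050 Hz.
        audio = vocoder.convert_spectrogram_to_audio(spec=spec)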
+ References ---------- diff --git a/docs/source/tts/tts_all.bib b/docs/source/tts/tts_all.bib index 0cd4fa8de..528b96976 100644 --- a/docs/source/tts/tts_all.bib +++ b/docs/source/tts/tts_all.bib @@ -37,4 +37,11 @@ eprint={2005.05106}, archivePrefix={arXiv}, primaryClass={cs.SD} -} \ No newline at end of file +} + +@article{kong2020hifi, + title={HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis}, + author={Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung}, + journal={arXiv preprint arXiv:2010.05646}, + year={2020} +} diff --git a/examples/tts/conf/hifigan/hifigan.yaml b/examples/tts/conf/hifigan/hifigan.yaml new file mode 100644 index 000000000..a5656e634 --- /dev/null +++ b/examples/tts/conf/hifigan/hifigan.yaml @@ -0,0 +1,84 @@ +name: "HifiGan" +train_dataset: ??? +validation_datasets: ??? + +defaults: + - model/generator: v1 + +model: + train_ds: + dataset: + _target_: "nemo.collections.tts.data.datalayers.AudioDataset" + manifest_filepath: ${train_dataset} + max_duration: null + min_duration: 0.75 + n_segments: 8192 + trim: false + dataloader_params: + drop_last: false + shuffle: true + batch_size: 16 + num_workers: 4 + + validation_ds: + dataset: + _target_: "nemo.collections.tts.data.datalayers.AudioDataset" + manifest_filepath: ${validation_datasets} + max_duration: null + min_duration: null + n_segments: -1 + trim: false + dataloader_params: + drop_last: false + shuffle: false + batch_size: 3 + num_workers: 4 + + preprocessor: + _target_: nemo.collections.asr.parts.features.FilterbankFeatures + dither: 0.0 + frame_splicing: 1 + nfilt: 80 + highfreq: 8000 + log: true + log_zero_guard_type: clamp + log_zero_guard_value: 1e-05 + lowfreq: 0 + mag_power: 1.0 + n_fft: 1024 + n_window_size: 1024 + n_window_stride: 256 + normalize: null + pad_to: 0 + pad_value: -11.52 + preemph: null + sample_rate: 22050 + stft_exact_pad: true + stft_conv: false + window: hann + use_grads: false + + optim: + lr: 0.0002 + adam_b1: 0.8 + adam_b2: 0.99 + lr_decay: 0.999 + lr_step: 15 + +trainer: + gpus: -1 # number of gpus + max_epochs: 1000000 + num_nodes: 1 + accelerator: ddp + accumulate_grad_batches: 1 + checkpoint_callback: False # Provided by exp_manager + logger: False # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 10 + automatic_optimization: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: True + create_checkpoint_callback: True diff --git a/examples/tts/conf/hifigan/model/generator/v1.yaml b/examples/tts/conf/hifigan/model/generator/v1.yaml new file mode 100644 index 000000000..30b819e97 --- /dev/null +++ b/examples/tts/conf/hifigan/model/generator/v1.yaml @@ -0,0 +1,8 @@ +# @package _group_ +_target_: nemo.collections.tts.modules.hifigan_modules.Generator +resblock: 1 +upsample_rates: [8,8,2,2] +upsample_kernel_sizes: [16,16,4,4] +upsample_initial_channel: 512 +resblock_kernel_sizes: [3,7,11] +resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] diff --git a/examples/tts/conf/hifigan/model/generator/v2.yaml b/examples/tts/conf/hifigan/model/generator/v2.yaml new file mode 100644 index 000000000..ef33801d5 --- /dev/null +++ b/examples/tts/conf/hifigan/model/generator/v2.yaml @@ -0,0 +1,8 @@ +# @package _group_ +_target_: nemo.collections.tts.modules.hifigan_modules.Generator +resblock: 1 +upsample_rates: [8,8,2,2] +upsample_kernel_sizes: [16,16,4,4] +upsample_initial_channel: 128 +resblock_kernel_sizes: [3,7,11] +resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] diff --git 
a/examples/tts/conf/hifigan/model/generator/v3.yaml b/examples/tts/conf/hifigan/model/generator/v3.yaml new file mode 100644 index 000000000..734f0a330 --- /dev/null +++ b/examples/tts/conf/hifigan/model/generator/v3.yaml @@ -0,0 +1,8 @@ +# @package _group_ +_target_: nemo.collections.tts.modules.hifigan_modules.Generator +resblock: 2 +upsample_rates: [8,8,4] +upsample_kernel_sizes: [16,16,8] +upsample_initial_channel: 256 +resblock_kernel_sizes: [3,5,7] +resblock_dilation_sizes: [[1,2], [2,6], [3,12]] diff --git a/examples/tts/conf/melgan.yaml b/examples/tts/conf/melgan.yaml index 25c3126a5..05b4293d9 100644 --- a/examples/tts/conf/melgan.yaml +++ b/examples/tts/conf/melgan.yaml @@ -28,7 +28,7 @@ model: _target_: "nemo.collections.tts.data.datalayers.AudioDataset" manifest_filepath: ${validation_datasets} max_duration: null - min_duration: 0.75 + min_duration: null n_segments: -1 trim: false dataloader_params: diff --git a/examples/tts/hifigan.py b/examples/tts/hifigan.py new file mode 100644 index 000000000..9afe8c1d1 --- /dev/null +++ b/examples/tts/hifigan.py @@ -0,0 +1,34 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl + +from nemo.collections.common.callbacks import LogEpochTimeCallback +from nemo.collections.tts.models import HifiGanModel +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf/hifigan", config_name="hifigan") +def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + model = HifiGanModel(cfg=cfg.model, trainer=trainer) + epoch_time_logger = LogEpochTimeCallback() + trainer.callbacks.extend([epoch_time_logger]) + trainer.fit(model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/nemo/collections/asr/parts/features.py b/nemo/collections/asr/parts/features.py index 3083f5855..3231502c1 100644 --- a/nemo/collections/asr/parts/features.py +++ b/nemo/collections/asr/parts/features.py @@ -220,6 +220,7 @@ class FilterbankFeatures(nn.Module): stft_conv=False, pad_value=0, mag_power=2.0, + use_grads=False, ): super().__init__() self.log_zero_guard_value = log_zero_guard_value @@ -301,6 +302,11 @@ class FilterbankFeatures(nn.Module): f"log_zero_guard_type parameter. It must be either 'add' or " f"'clamp'." 
) + + self.use_grads = use_grads + if not use_grads: + self.forward = torch.no_grad()(self.forward) + # log_zero_guard_value is the the small we want to use, we support # an actual number, or "tiny", or "eps" self.log_zero_guard_type = log_zero_guard_type @@ -311,6 +317,7 @@ class FilterbankFeatures(nn.Module): logging.debug(f"n_mels: {nfilt}") logging.debug(f"fmin: {lowfreq}") logging.debug(f"fmax: {highfreq}") + logging.debug(f"using grads: {use_grads}") def log_zero_guard_value_fn(self, x): if isinstance(self.log_zero_guard_value, str): @@ -334,7 +341,6 @@ class FilterbankFeatures(nn.Module): def filter_banks(self): return self.fb - @torch.no_grad() def forward(self, x, seq_len): seq_len = self.get_seq_len(seq_len.float()) @@ -356,7 +362,9 @@ class FilterbankFeatures(nn.Module): # torch returns real, imag; so convert to magnitude if not self.stft_conv: - x = torch.sqrt(x.pow(2).sum(-1)) + # guard is needed for sqrt if grads are passed through + guard = 0 if not self.use_grads else CONSTANT + x = torch.sqrt(x.pow(2).sum(-1) + guard) # get power spectrum if self.mag_power != 1.0: diff --git a/nemo/collections/tts/losses/hifigan_losses.py b/nemo/collections/tts/losses/hifigan_losses.py new file mode 100644 index 000000000..dd6d61309 --- /dev/null +++ b/nemo/collections/tts/losses/hifigan_losses.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# MIT License +# +# Copyright (c) 2020 Jungil Kong +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
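To illustrate the FilterbankFeatures change above (annotation, not part of the patch): when gradients are passed through the mel extractor, the gradient of sqrt at an exactly-zero power bin is 0/0 and turns into NaN, which the small additive guard avoids. A minimal sketch, using a stand-in epsilon for the module-level CONSTANT referenced in features.py:

    import torch

    eps = 1e-5  # stand-in for the CONSTANT guard used in features.py

    # Unguarded: gradient of sqrt(sum(x^2)) at x == 0 is 0/0 -> NaN.
    x = torch.zeros(3, requires_grad=True)
    torch.sqrt(x.pow(2).sum(-1)).backward()
    print(x.grad)  # tensor([nan, nan, nan])

    # Guarded, as in the patched forward(): gradients stay finite.
    x = torch.zeros(3, requires_grad=True)
    torch.sqrt(x.pow(2).sum(-1) + eps).backward()
    print(x.grad)  # tensor([0., 0., 0.])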
+ +# The forward functions onf the following classes are based on code from https://github.com/jik876/hifi-gan: +# FeatureMatchingLoss, DiscriminatorLoss, GeneratorLoss + +import torch + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types.elements import LossType, VoidType +from nemo.core.neural_types.neural_type import NeuralType + + +class FeatureMatchingLoss(Loss): + """Feature Matching Loss module""" + + @property + def input_types(self): + return { + "fmap_r": NeuralType(elements_type=VoidType()), + "fmap_g": NeuralType(elements_type=VoidType()), + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +class DiscriminatorLoss(Loss): + """Discriminator Loss module""" + + @property + def input_types(self): + return { + "disc_real_outputs": NeuralType(('B', 'T'), VoidType()), + "disc_generated_outputs": NeuralType(('B', 'T'), VoidType()), + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + "real_losses": NeuralType(elements_type=LossType()), + "fake_losses": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class GeneratorLoss(Loss): + """Generator Loss module""" + + @property + def input_types(self): + return { + "disc_outputs": NeuralType(('B', 'T'), VoidType()), + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + "fake_losses": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index aeb50cc65..22256e17f 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -15,6 +15,7 @@ from nemo.collections.tts.models.degli import DegliModel from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel from nemo.collections.tts.models.glow_tts import GlowTTSModel +from nemo.collections.tts.models.hifigan import HifiGanModel from nemo.collections.tts.models.melgan import MelGanModel from nemo.collections.tts.models.squeezewave import SqueezeWaveModel from nemo.collections.tts.models.tacotron2 import Tacotron2Model @@ -37,4 +38,5 @@ __all__ = [ "TalkNetSpectModel", "UniGlowModel", "MelGanModel", + "HifiGanModel", ] diff --git a/nemo/collections/tts/models/hifigan.py b/nemo/collections/tts/models/hifigan.py new file mode 100644 index 000000000..e746e7937 --- /dev/null +++ b/nemo/collections/tts/models/hifigan.py @@ -0,0 +1,219 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import torch +import torch.nn.functional as F +import wandb +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.loggers.wandb import WandbLogger + +from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy +from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss +from nemo.collections.tts.models.base import Vocoder +from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator +from nemo.core.classes.common import typecheck +from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType +from nemo.core.neural_types.neural_type import NeuralType +from nemo.utils import logging + + +class HifiGanModel(Vocoder): + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + super().__init__(cfg=cfg, trainer=trainer) + + self.audio_to_melspec_precessor = instantiate(cfg.preprocessor) + # use a different melspec extractor because: + # 1. we need to pass grads + # 2. we need remove fmax limitation + self.trg_melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True) + self.generator = instantiate(cfg.generator) + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + self.feature_loss = FeatureMatchingLoss() + self.discriminator_loss = DiscriminatorLoss() + self.generator_loss = GeneratorLoss() + + self.sample_rate = self._cfg.preprocessor.sample_rate + + def configure_optimizers(self): + self.optim_g = torch.optim.AdamW( + self.generator.parameters(), self._cfg.optim.lr, betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2] + ) + self.optim_d = torch.optim.AdamW( + itertools.chain(self.msd.parameters(), self.mpd.parameters()), + self._cfg.optim.lr, + betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2], + ) + + self.scheduler_g = torch.optim.lr_scheduler.StepLR( + self.optim_g, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay, + ) + self.scheduler_d = torch.optim.lr_scheduler.StepLR( + self.optim_d, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay, + ) + + return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d] + + @property + def input_types(self): + return { + "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)), + } + + @typecheck() + def forward(self, *, spec): + """ + Runs the generator, for inputs and outputs see input_types, and output_types + """ + return self.generator(x=spec) + + @typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())}) + def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor': + return self(spec=spec).squeeze(1) + + def training_step(self, batch, batch_idx, optimizer_idx): + audio, audio_len = batch + # mel as input for generator + audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len) + # mel as input for L1 mel 
loss + audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len) + audio = audio.unsqueeze(1) + + audio_pred = self.generator(x=audio_mel) + audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len) + + # train discriminator + self.optim_d.zero_grad() + mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach()) + loss_disc_mpd, _, _ = self.discriminator_loss( + disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen + ) + msd_score_real, msd_score_gen, _, _ = self.msd(y=audio, y_hat=audio_pred.detach()) + loss_disc_msd, _, _ = self.discriminator_loss( + disc_real_outputs=msd_score_real, disc_generated_outputs=msd_score_gen + ) + loss_d = loss_disc_msd + loss_disc_mpd + self.manual_backward(loss_d, self.optim_d) + self.optim_d.step() + + # train generator + self.optim_g.zero_grad() + loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel) * 45 + _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred) + _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred) + loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen) + loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen) + loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen) + loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen) + loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel + self.manual_backward(loss_g, self.optim_g) + self.optim_g.step() + + metrics = { + "g_l1_loss": loss_mel, + "g_loss_fm_mpd": loss_fm_mpd, + "g_loss_fm_msd": loss_fm_msd, + "g_loss_gen_mpd": loss_gen_mpd, + "g_loss_gen_msd": loss_gen_msd, + "g_loss": loss_g, + "d_loss_mpd": loss_disc_mpd, + "d_loss_msd": loss_disc_msd, + "d_loss": loss_d, + "global_step": self.global_step, + "lr": self.optim_g.param_groups[0]['lr'], + } + self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True) + + def validation_step(self, batch, batch_idx): + audio, audio_len = batch + audio_mel, audio_mel_len = self.audio_to_melspec_precessor(audio, audio_len) + audio_pred = self(spec=audio_mel) + + audio_pred_mel, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len) + loss_mel = F.l1_loss(audio_mel, audio_pred_mel) + + self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True) + + # plot audio once per epoch + if batch_idx == 0 and isinstance(self.logger, WandbLogger): + clips = [] + specs = [] + for i in range(min(5, audio.shape[0])): + clips += [ + wandb.Audio( + audio[i, : audio_len[i]].data.cpu().numpy(), + caption=f"real audio {i}", + sample_rate=self.sample_rate, + ), + wandb.Audio( + audio_pred[i, 0, : audio_len[i]].data.cpu().numpy().astype('float32'), + caption=f"generated audio {i}", + sample_rate=self.sample_rate, + ), + ] + specs += [ + wandb.Image( + plot_spectrogram_to_numpy(audio_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()), + caption=f"real audio {i}", + ), + wandb.Image( + plot_spectrogram_to_numpy(audio_pred_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()), + caption=f"generated audio {i}", + ), + ] + + self.logger.experiment.log({"audio": clips, "specs": specs}, commit=False) + + def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"): + if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig): + raise ValueError(f"No dataset for {name}") + if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig): + raise ValueError(f"No dataloder_params for {name}") + if 
shuffle_should_be: + if 'shuffle' not in cfg.dataloader_params: + logging.warning( + f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its " + "config. Manually setting to True" + ) + with open_dict(cfg["dataloader_params"]): + cfg.dataloader_params.shuffle = True + elif not cfg.dataloader_params.shuffle: + logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!") + elif not shuffle_should_be and cfg.dataloader_params.shuffle: + logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!") + + dataset = instantiate(cfg.dataset) + return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params) + + def setup_training_data(self, cfg): + self._train_dl = self.__setup_dataloader_from_config(cfg) + + def setup_validation_data(self, cfg): + self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation") + + @classmethod + def list_available_models(cls) -> 'Optional[Dict[str, str]]': + # TODO + pass diff --git a/nemo/collections/tts/modules/hifigan_modules.py b/nemo/collections/tts/modules/hifigan_modules.py new file mode 100644 index 000000000..37fe242ae --- /dev/null +++ b/nemo/collections/tts/modules/hifigan_modules.py @@ -0,0 +1,434 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2020 Jungil Kong +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
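As a shape sanity check on the generator defined in this file (annotation, not part of the patch): with the v1 hyperparameters from examples/tts/conf/hifigan/model/generator/v1.yaml, the upsample rates multiply to 8*8*2*2 = 256, matching the 256-sample hop (n_window_stride) of the mel extractor, so every mel frame is upsampled to 256 audio samples. A minimal sketch:

    import torch

    from nemo.collections.tts.modules.hifigan_modules import Generator

    gen = Generator(
        resblock=1,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    )

    mel = torch.randn(2, 80, 32)  # (batch, n_mels, frames)
    audio = gen(x=mel)            # keyword argument required by the typecheck'd forward
    print(audio.shape)            # torch.Size([2, 1, 8192]) -- 32 frames * 256 samples per frame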
+ +# The following functions/classes were based on code from https://github.com/jik876/hifi-gan: +# ResBlock1, ResBlock2, Generator, DiscriminatorP, DiscriminatorS, MultiScaleDiscriminator, +# MultiPeriodDiscriminator, init_weights, get_padding + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from nemo.core.classes import NeuralModule +from nemo.core.classes.common import typecheck +from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType +from nemo.core.neural_types.neural_type import NeuralType + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size, dilation): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size, dilation): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(NeuralModule): + def __init__( + self, + resblock, + upsample_rates, + upsample_kernel_sizes, + upsample_initial_channel, + resblock_kernel_sizes, + resblock_dilation_sizes, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock == 1 else 
ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + @property + def input_types(self): + return { + "x": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'S', 'T'), AudioSignal()), + } + + @typecheck() + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(NeuralModule): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + @property + def input_types(self): + return { + "x": NeuralType(('B', 'S', 'T'), AudioSignal()), + } + + @property + def output_types(self): + return { + "decision": NeuralType(('B', 'T'), VoidType()), + "feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()), + } + + @typecheck() + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(NeuralModule): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorP(2), DiscriminatorP(3), DiscriminatorP(5), DiscriminatorP(7), DiscriminatorP(11),] + ) + + @property + def output_types(self): + return { + "y": NeuralType(('B', 'S', 'T'), AudioSignal()), + "y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()), + } + + @property + def 
output_types(self): + return { + "real_scores": NeuralType(('B', 'T'), VoidType()), + "fake_scores": NeuralType(('B', 'T'), VoidType()), + "real_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()), + "fake_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()), + } + + @typecheck() + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(NeuralModule): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + @property + def input_types(self): + return { + "x": NeuralType(('B', 'S', 'T'), AudioSignal()), + } + + @property + def output_types(self): + return { + "decision": NeuralType(('B', 'T'), VoidType()), + "feature_maps": NeuralType(("B", "C", "T"), VoidType()), + } + + @typecheck() + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(NeuralModule): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorS(use_spectral_norm=True), DiscriminatorS(), DiscriminatorS(),] + ) + self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + @property + def output_types(self): + return { + "y": NeuralType(('B', 'S', 'T'), AudioSignal()), + "y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()), + } + + @property + def output_types(self): + return { + "real_scores": NeuralType(('B', 'T'), VoidType()), + "fake_scores": NeuralType(('B', 'T'), VoidType()), + "real_feature_maps": NeuralType(("B", "C", "T"), VoidType()), + "fake_feature_maps": NeuralType(("B", "C", "T"), VoidType()), + } + + @typecheck() + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs
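To round out the module additions (annotation, not part of the patch), a minimal sketch of how the two discriminator stacks are queried, mirroring the calls in HifiGanModel.training_step: each returns per-sub-discriminator score tensors plus the feature maps consumed by FeatureMatchingLoss:

    import torch

    from nemo.collections.tts.modules.hifigan_modules import (
        MultiPeriodDiscriminator,
        MultiScaleDiscriminator,
    )

    mpd = MultiPeriodDiscriminator()  # periods 2, 3, 5, 7, 11
    msd = MultiScaleDiscriminator()   # raw audio plus two average-pooled scales

    real = torch.randn(2, 1, 8192)    # (batch, 1, samples), e.g. one 8192-sample training segment
    fake = torch.randn(2, 1, 8192)

    real_scores, fake_scores, real_fmaps, fake_fmaps = mpd(y=real, y_hat=fake)
    print(len(real_scores), len(real_fmaps))  # 5 and 5: one entry per period discriminator

    real_scores, fake_scores, real_fmaps, fake_fmaps = msd(y=real, y_hat=fake)
    print(len(real_scores), len(real_fmaps))  # 3 and 3: one entry per scale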