HifiGAN MelSpectrogram Vocoder Model (#1706)

* HifiGAN: initial commit

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* switched to manual optimization, added grads to the FilterbankFeatures object,
added v1, v2, v3 configurations for the generator,
added multiple audio/spec examples in wandb

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* added trg_mel_fn for different mel params in L1 loss

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* fixed CI checks

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* moved losses to separate classes, fixed exp_manager in yaml, style fixes

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* fixed `min_duration` in yamls of MelGAN and HifiGAN

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* fixed f_max in loss calculation, added docs

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Jason <jasoli@nvidia.com>
Felix Kreuk 2021-02-16 23:56:29 +02:00 committed by GitHub
parent 562087498d
commit 7d8b7d679a
14 changed files with 955 additions and 4 deletions

.gitignore vendored

@@ -11,6 +11,7 @@ bert.pt.json
 work
 runs
 fastspeech_output
+.hydra
 # Byte-compiled / optimized / DLL files
 __pycache__/


@@ -51,6 +51,12 @@ Full-band MelGAN
 The :class:`MelGanModel<nemo.collections.tts.models.MelGanModel>` class implements Full-band MelGAN as described in
 :cite:`yang2020multiband`. MelGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
 
+HifiGAN
+-------
+
+The :class:`HifiGanModel<nemo.collections.tts.models.HifiGanModel>` class implements HifiGAN as described in
+:cite:`kong2020hifi`. HifiGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
+
 References
 ----------
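
For context, a minimal usage sketch of the new class (not part of the diff); it assumes an already-trained HifiGanModel instance named `model`, which is hypothetical here:

import torch

# dummy batch of mel spectrograms: (batch, n_mels=80, frames)
spec = torch.randn(1, 80, 200)

# HifiGanModel converts mel spectrograms to waveforms of shape (batch, samples)
audio = model.convert_spectrogram_to_audio(spec=spec)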


@@ -37,4 +37,11 @@
     eprint={2005.05106},
     archivePrefix={arXiv},
     primaryClass={cs.SD}
 }
+
+@article{kong2020hifi,
+  title={HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis},
+  author={Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung},
+  journal={arXiv preprint arXiv:2010.05646},
+  year={2020}
+}

examples/tts/conf/hifigan/hifigan.yaml Normal file

@@ -0,0 +1,84 @@
name: "HifiGan"

train_dataset: ???
validation_datasets: ???

defaults:
  - model/generator: v1

model:
  train_ds:
    dataset:
      _target_: "nemo.collections.tts.data.datalayers.AudioDataset"
      manifest_filepath: ${train_dataset}
      max_duration: null
      min_duration: 0.75
      n_segments: 8192
      trim: false
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 16
      num_workers: 4

  validation_ds:
    dataset:
      _target_: "nemo.collections.tts.data.datalayers.AudioDataset"
      manifest_filepath: ${validation_datasets}
      max_duration: null
      min_duration: null
      n_segments: -1
      trim: false
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 3
      num_workers: 4

  preprocessor:
    _target_: nemo.collections.asr.parts.features.FilterbankFeatures
    dither: 0.0
    frame_splicing: 1
    nfilt: 80
    highfreq: 8000
    log: true
    log_zero_guard_type: clamp
    log_zero_guard_value: 1e-05
    lowfreq: 0
    mag_power: 1.0
    n_fft: 1024
    n_window_size: 1024
    n_window_stride: 256
    normalize: null
    pad_to: 0
    pad_value: -11.52
    preemph: null
    sample_rate: 22050
    stft_exact_pad: true
    stft_conv: false
    window: hann
    use_grads: false

  optim:
    lr: 0.0002
    adam_b1: 0.8
    adam_b2: 0.99
    lr_decay: 0.999
    lr_step: 15

trainer:
  gpus: -1 # number of gpus
  max_epochs: 1000000
  num_nodes: 1
  accelerator: ddp
  accumulate_grad_batches: 1
  checkpoint_callback: False # Provided by exp_manager
  logger: False # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 10
  automatic_optimization: false

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: True
  create_checkpoint_callback: True
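
With this config, training is launched as `python examples/tts/hifigan.py train_dataset=<manifest> validation_datasets=<manifest>`; Hydra composes the generator section from the model/generator group selected under defaults. As a sketch of how the inline sections are consumed, the preprocessor can be instantiated the same way HifiGanModel.__init__ does below; the yaml path is inferred from the script's config_path and may differ:

from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch

cfg = OmegaConf.load("examples/tts/conf/hifigan/hifigan.yaml")  # assumed path

# _target_ resolves to nemo.collections.asr.parts.features.FilterbankFeatures
featurizer = instantiate(cfg.model.preprocessor)

audio = torch.randn(1, 22050)                # one second of audio at 22050 Hz
audio_len = torch.tensor([22050])
mel, mel_len = featurizer(audio, audio_len)  # mel: (1, nfilt=80, frames)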

examples/tts/conf/hifigan/model/generator/v1.yaml Normal file

@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 1
upsample_rates: [8,8,2,2]
upsample_kernel_sizes: [16,16,4,4]
upsample_initial_channel: 512
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]

examples/tts/conf/hifigan/model/generator/v2.yaml Normal file

@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 1
upsample_rates: [8,8,2,2]
upsample_kernel_sizes: [16,16,4,4]
upsample_initial_channel: 128
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]

examples/tts/conf/hifigan/model/generator/v3.yaml Normal file

@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 2
upsample_rates: [8,8,4]
upsample_kernel_sizes: [16,16,8]
upsample_initial_channel: 256
resblock_kernel_sizes: [3,5,7]
resblock_dilation_sizes: [[1,2], [2,6], [3,12]]
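
A quick arithmetic check on these variants (not from the commit): the product of upsample_rates must equal the preprocessor hop size (n_window_stride: 256), because the generator expands each mel frame into exactly one hop of waveform:

import math

hop = 256  # n_window_stride in hifigan.yaml
for name, rates in {"v1/v2": [8, 8, 2, 2], "v3": [8, 8, 4]}.items():
    assert math.prod(rates) == hop, name
print("every generator variant upsamples by", hop)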

examples/tts/conf/melgan.yaml

@@ -28,7 +28,7 @@ model:
       _target_: "nemo.collections.tts.data.datalayers.AudioDataset"
       manifest_filepath: ${validation_datasets}
       max_duration: null
-      min_duration: 0.75
+      min_duration: null
       n_segments: -1
       trim: false
     dataloader_params:

examples/tts/hifigan.py Normal file

@@ -0,0 +1,34 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import HifiGanModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/hifigan", config_name="hifigan")
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = HifiGanModel(cfg=cfg.model, trainer=trainer)
    epoch_time_logger = LogEpochTimeCallback()
    trainer.callbacks.extend([epoch_time_logger])
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter

nemo/collections/asr/parts/features.py

@@ -220,6 +220,7 @@ class FilterbankFeatures(nn.Module):
         stft_conv=False,
         pad_value=0,
         mag_power=2.0,
+        use_grads=False,
     ):
         super().__init__()
         self.log_zero_guard_value = log_zero_guard_value
@@ -301,6 +302,11 @@ class FilterbankFeatures(nn.Module):
                 f"log_zero_guard_type parameter. It must be either 'add' or "
                 f"'clamp'."
             )
+
+        self.use_grads = use_grads
+        if not use_grads:
+            self.forward = torch.no_grad()(self.forward)
+
         # log_zero_guard_value is the small value we want to use, we support
         # an actual number, or "tiny", or "eps"
         self.log_zero_guard_type = log_zero_guard_type
@@ -311,6 +317,7 @@ class FilterbankFeatures(nn.Module):
         logging.debug(f"n_mels: {nfilt}")
         logging.debug(f"fmin: {lowfreq}")
         logging.debug(f"fmax: {highfreq}")
+        logging.debug(f"using grads: {use_grads}")
 
     def log_zero_guard_value_fn(self, x):
         if isinstance(self.log_zero_guard_value, str):
@@ -334,7 +341,6 @@ class FilterbankFeatures(nn.Module):
     def filter_banks(self):
         return self.fb
 
-    @torch.no_grad()
     def forward(self, x, seq_len):
         seq_len = self.get_seq_len(seq_len.float())
 
@@ -356,7 +362,9 @@ class FilterbankFeatures(nn.Module):
         # torch returns real, imag; so convert to magnitude
         if not self.stft_conv:
-            x = torch.sqrt(x.pow(2).sum(-1))
+            # guard is needed for sqrt if grads are passed through
+            guard = 0 if not self.use_grads else CONSTANT
+            x = torch.sqrt(x.pow(2).sum(-1) + guard)
 
         # get power spectrum
         if self.mag_power != 1.0:
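
Two details of this diff are easy to miss: dropping the @torch.no_grad() decorator and instead rebinding forward keeps no-grad as the default while letting HifiGAN opt in to gradients, and the sqrt guard avoids the infinite derivative of sqrt at zero once gradients flow (CONSTANT is presumably a small module-level epsilon in features.py). A toy sketch of the rebinding trick; ScaledFeature is a made-up stand-in, not NeMo code:

import torch
import torch.nn as nn

class ScaledFeature(nn.Module):
    def __init__(self, use_grads=False):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))
        if not use_grads:
            # shadow the class method with a no_grad-wrapped bound method
            self.forward = torch.no_grad()(self.forward)

    def forward(self, x):
        return self.scale * x

x = torch.ones(3)
print(ScaledFeature(use_grads=True)(x).requires_grad)   # True: autograd graph built
print(ScaledFeature(use_grads=False)(x).requires_grad)  # False: runs under no_grad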

nemo/collections/tts/losses/hifigan_losses.py Normal file

@@ -0,0 +1,132 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The forward functions of the following classes are based on code from https://github.com/jik876/hifi-gan:
# FeatureMatchingLoss, DiscriminatorLoss, GeneratorLoss
import torch

from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types.elements import LossType, VoidType
from nemo.core.neural_types.neural_type import NeuralType


class FeatureMatchingLoss(Loss):
    """Feature Matching Loss module"""

    @property
    def input_types(self):
        return {
            "fmap_r": NeuralType(elements_type=VoidType()),
            "fmap_g": NeuralType(elements_type=VoidType()),
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, fmap_r, fmap_g):
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))
        return loss * 2


class DiscriminatorLoss(Loss):
    """Discriminator Loss module"""

    @property
    def input_types(self):
        return {
            "disc_real_outputs": NeuralType(('B', 'T'), VoidType()),
            "disc_generated_outputs": NeuralType(('B', 'T'), VoidType()),
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
            "real_losses": NeuralType(elements_type=LossType()),
            "fake_losses": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, disc_real_outputs, disc_generated_outputs):
        loss = 0
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean((1 - dr) ** 2)
            g_loss = torch.mean(dg ** 2)
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())
        return loss, r_losses, g_losses


class GeneratorLoss(Loss):
    """Generator Loss module"""

    @property
    def input_types(self):
        return {
            "disc_outputs": NeuralType(('B', 'T'), VoidType()),
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
            "fake_losses": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, disc_outputs):
        loss = 0
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean((1 - dg) ** 2)
            gen_losses.append(l)
            loss += l
        return loss, gen_losses
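
For orientation, a sketch of how these loss modules are invoked (mirroring HifiGanModel.training_step later in this commit); the tensors are dummies standing in for per-discriminator score lists and feature maps:

import torch
from nemo.collections.tts.losses.hifigan_losses import (
    DiscriminatorLoss,
    FeatureMatchingLoss,
    GeneratorLoss,
)

# two discriminators, each producing a flattened (B, T') score map
real_scores = [torch.rand(4, 96), torch.rand(4, 48)]
fake_scores = [torch.rand(4, 96), torch.rand(4, 48)]
# per-discriminator lists of intermediate feature maps
fmap_real = [[torch.rand(4, 32, 96)], [torch.rand(4, 32, 48)]]
fmap_fake = [[torch.rand(4, 32, 96)], [torch.rand(4, 32, 48)]]

loss_d, real_losses, fake_losses = DiscriminatorLoss()(
    disc_real_outputs=real_scores, disc_generated_outputs=fake_scores
)
loss_g, gen_losses = GeneratorLoss()(disc_outputs=fake_scores)
loss_fm = FeatureMatchingLoss()(fmap_r=fmap_real, fmap_g=fmap_fake)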

nemo/collections/tts/models/__init__.py

@@ -15,6 +15,7 @@
 from nemo.collections.tts.models.degli import DegliModel
 from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
 from nemo.collections.tts.models.glow_tts import GlowTTSModel
+from nemo.collections.tts.models.hifigan import HifiGanModel
 from nemo.collections.tts.models.melgan import MelGanModel
 from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
 from nemo.collections.tts.models.tacotron2 import Tacotron2Model
@@ -37,4 +38,5 @@ __all__ = [
     "TalkNetSpectModel",
     "UniGlowModel",
     "MelGanModel",
+    "HifiGanModel",
 ]

nemo/collections/tts/models/hifigan.py Normal file

@@ -0,0 +1,219 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools

import torch
import torch.nn.functional as F
import wandb
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.loggers.wandb import WandbLogger

from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss
from nemo.collections.tts.models.base import Vocoder
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging


class HifiGanModel(Vocoder):
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need to remove the fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()
        self.sample_rate = self._cfg.preprocessor.sample_rate
    def configure_optimizers(self):
        self.optim_g = torch.optim.AdamW(
            self.generator.parameters(), self._cfg.optim.lr, betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2]
        )
        self.optim_d = torch.optim.AdamW(
            itertools.chain(self.msd.parameters(), self.mpd.parameters()),
            self._cfg.optim.lr,
            betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2],
        )
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(
            self.optim_g, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
        )
        self.scheduler_d = torch.optim.lr_scheduler.StepLR(
            self.optim_d, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
        )
        return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)),
        }

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator; for inputs and outputs see input_types and output_types.
        """
        return self.generator(x=spec)

    @typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
    def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
        return self(spec=spec).squeeze(1)
    def training_step(self, batch, batch_idx, optimizer_idx):
        audio, audio_len = batch
        # mel as input for generator
        audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)
        # mel as input for L1 mel loss
        audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)

        # train discriminator
        self.optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
        )
        msd_score_real, msd_score_gen, _, _ = self.msd(y=audio, y_hat=audio_pred.detach())
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real, disc_generated_outputs=msd_score_gen
        )
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d, self.optim_d)
        self.optim_d.step()

        # train generator
        self.optim_g.zero_grad()
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel) * 45
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
        loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen)
        loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel
        self.manual_backward(loss_g, self.optim_g)
        self.optim_g.step()

        metrics = {
            "g_l1_loss": loss_mel,
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
    def validation_step(self, batch, batch_idx):
        audio, audio_len = batch
        audio_mel, audio_mel_len = self.audio_to_melspec_precessor(audio, audio_len)
        audio_pred = self(spec=audio_mel)
        audio_pred_mel, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)
        self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger, WandbLogger):
            clips = []
            specs = []
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, : audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i, 0, : audio_len[i]].data.cpu().numpy().astype('float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"real audio {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"generated audio {i}",
                    ),
                ]

            self.logger.experiment.log({"audio": clips, "specs": specs}, commit=False)
    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        # TODO
        pass
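
A minimal inference sketch for the finished model (not part of the commit); the checkpoint path is hypothetical, and the remove_weight_norm call mirrors the reference HiFi-GAN's inference-time cleanup:

import torch
from nemo.collections.tts.models import HifiGanModel

model = HifiGanModel.load_from_checkpoint("hifigan.ckpt")  # hypothetical path
model.eval()
model.generator.remove_weight_norm()  # optional, as in the reference implementation

spec = torch.randn(1, 80, 100)  # (B, n_mels, frames)
with torch.no_grad():
    audio = model.convert_spectrogram_to_audio(spec=spec)  # shape (1, 100 * 256)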

nemo/collections/tts/modules/hifigan_modules.py Normal file

@@ -0,0 +1,434 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
# ResBlock1, ResBlock2, Generator, DiscriminatorP, DiscriminatorS, MultiScaleDiscriminator,
# MultiPeriodDiscriminator, init_weights, get_padding
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from nemo.core.classes import NeuralModule
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
from nemo.core.neural_types.neural_type import NeuralType

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilation):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilation):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(NeuralModule):
    def __init__(
        self,
        resblock,
        upsample_rates,
        upsample_kernel_sizes,
        upsample_initial_channel,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if resblock == 1 else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    @property
    def input_types(self):
        return {
            "x": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @typecheck()
    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(NeuralModule):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    @property
    def input_types(self):
        return {
            "x": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "decision": NeuralType(('B', 'T'), VoidType()),
            "feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
        }

    @typecheck()
    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(NeuralModule):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(2), DiscriminatorP(3), DiscriminatorP(5), DiscriminatorP(7), DiscriminatorP(11),]
        )

    @property
    def input_types(self):
        return {
            "y": NeuralType(('B', 'S', 'T'), AudioSignal()),
            "y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "real_scores": NeuralType(('B', 'T'), VoidType()),
            "fake_scores": NeuralType(('B', 'T'), VoidType()),
            "real_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
            "fake_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
        }

    @typecheck()
    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(NeuralModule):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    @property
    def input_types(self):
        return {
            "x": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "decision": NeuralType(('B', 'T'), VoidType()),
            "feature_maps": NeuralType(("B", "C", "T"), VoidType()),
        }

    @typecheck()
    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(NeuralModule):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=True), DiscriminatorS(), DiscriminatorS(),]
        )
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])

    @property
    def input_types(self):
        return {
            "y": NeuralType(('B', 'S', 'T'), AudioSignal()),
            "y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "real_scores": NeuralType(('B', 'T'), VoidType()),
            "fake_scores": NeuralType(('B', 'T'), VoidType()),
            "real_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
            "fake_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
        }

    @typecheck()
    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
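
Finally, a shape check for the generator (dummy input, v2-sized channels, not part of the commit), confirming that the upsample_rates expand each mel frame into 256 samples:

import torch
from nemo.collections.tts.modules.hifigan_modules import Generator

gen = Generator(
    resblock=1,
    upsample_rates=[8, 8, 2, 2],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=128,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)
mel = torch.randn(2, 80, 40)  # (B, n_mels, frames)
audio = gen(x=mel)
print(audio.shape)  # expected: torch.Size([2, 1, 10240]), i.e. 40 frames * 256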