HifiGAN MelSpectrogram Vocoder Model (#1706)

* HifiGAN: initial commit

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* switched to manual optimization, added grads to the FilterbankFeatures object,
added v1, v2, v3 configurations for the generator,
added multiple audio/spec examples in wandb

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* added trg_mel_fn for different mel params in L1 loss

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* fixed CI checks

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* moved losses to separate classes, fixed exp_manager in yaml, style fixes

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* fixed `min_duration` in yamls of MelGAN and HifiGAN

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

* fixed f_max in loss calculation, added docs

Signed-off-by: Felix Kreuk <felixkreuk@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Jason <jasoli@nvidia.com>
Felix Kreuk 2021-02-16 23:56:29 +02:00 committed by GitHub
parent 562087498d
commit 7d8b7d679a
14 changed files with 955 additions and 4 deletions

.gitignore

@@ -11,6 +11,7 @@ bert.pt.json
work
runs
fastspeech_output
.hydra
# Byte-compiled / optimized / DLL files
__pycache__/


@@ -51,6 +51,12 @@ Full-band MelGAN
The :class:`MelGanModel<nemo.collections.tts.models.MelGanModel>` class implements Full-band MelGAN as described in
:cite:`yang2020multiband`. MelGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
HifiGAN
-------
The :class:`HifiGanModel<nemo.collections.tts.models.HifiGanModel>` class implements HifiGAN as described in
:cite:`kong2020hifi`. HifiGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
References
----------


@@ -37,4 +37,11 @@
eprint={2005.05106},
archivePrefix={arXiv},
primaryClass={cs.SD}
}
@article{kong2020hifi,
title={HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis},
author={Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung},
journal={arXiv preprint arXiv:2010.05646},
year={2020}
}

examples/tts/conf/hifigan/hifigan.yaml

@@ -0,0 +1,84 @@
name: "HifiGan"
train_dataset: ???
validation_datasets: ???
defaults:
- model/generator: v1
model:
  train_ds:
    dataset:
      _target_: "nemo.collections.tts.data.datalayers.AudioDataset"
      manifest_filepath: ${train_dataset}
      max_duration: null
      min_duration: 0.75
      n_segments: 8192
      trim: false
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 16
      num_workers: 4
  validation_ds:
    dataset:
      _target_: "nemo.collections.tts.data.datalayers.AudioDataset"
      manifest_filepath: ${validation_datasets}
      max_duration: null
      min_duration: null
      n_segments: -1
      trim: false
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 3
      num_workers: 4
  preprocessor:
    _target_: nemo.collections.asr.parts.features.FilterbankFeatures
    dither: 0.0
    frame_splicing: 1
    nfilt: 80
    highfreq: 8000
    log: true
    log_zero_guard_type: clamp
    log_zero_guard_value: 1e-05
    lowfreq: 0
    mag_power: 1.0
    n_fft: 1024
    n_window_size: 1024
    n_window_stride: 256
    normalize: null
    pad_to: 0
    pad_value: -11.52
    preemph: null
    sample_rate: 22050
    stft_exact_pad: true
    stft_conv: false
    window: hann
    use_grads: false
  optim:
    lr: 0.0002
    adam_b1: 0.8
    adam_b2: 0.99
    lr_decay: 0.999
    lr_step: 15
trainer:
  gpus: -1 # number of gpus
  max_epochs: 1000000
  num_nodes: 1
  accelerator: ddp
  accumulate_grad_batches: 1
  checkpoint_callback: False # Provided by exp_manager
  logger: False # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 10
  automatic_optimization: false
exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: True
  create_checkpoint_callback: True

examples/tts/conf/hifigan/model/generator/v1.yaml

@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 1
upsample_rates: [8,8,2,2]
upsample_kernel_sizes: [16,16,4,4]
upsample_initial_channel: 512
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]

examples/tts/conf/hifigan/model/generator/v2.yaml

@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 1
upsample_rates: [8,8,2,2]
upsample_kernel_sizes: [16,16,4,4]
upsample_initial_channel: 128
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]

examples/tts/conf/hifigan/model/generator/v3.yaml

@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 2
upsample_rates: [8,8,4]
upsample_kernel_sizes: [16,16,8]
upsample_initial_channel: 256
resblock_kernel_sizes: [3,5,7]
resblock_dilation_sizes: [[1,2], [2,6], [3,12]]
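In all three variants the upsample_rates must multiply out to the preprocessor hop length (n_window_stride: 256 in hifigan.yaml), since the generator emits one hop of raw audio per input mel frame. A small illustrative check, with the values copied from the configs above:

from functools import reduce
from operator import mul

# each generator variant must turn one mel frame into exactly one STFT hop
hop_length = 256  # n_window_stride from hifigan.yaml
variants = {"v1": [8, 8, 2, 2], "v2": [8, 8, 2, 2], "v3": [8, 8, 4]}
for name, rates in variants.items():
    assert reduce(mul, rates) == hop_length, f"{name} does not match the hop length"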

examples/tts/conf/melgan.yaml

@@ -28,7 +28,7 @@ model:
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
manifest_filepath: ${validation_datasets}
max_duration: null
- min_duration: 0.75
+ min_duration: null
n_segments: -1
trim: false
dataloader_params:

examples/tts/hifigan.py

@@ -0,0 +1,34 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import HifiGanModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/hifigan", config_name="hifigan")
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = HifiGanModel(cfg=cfg.model, trainer=trainer)
    epoch_time_logger = LogEpochTimeCallback()
    trainer.callbacks.extend([epoch_time_logger])
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
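A typical launch, assuming NeMo-style JSON manifests at placeholder paths; the generator variant can be swapped through the Hydra defaults group:

python examples/tts/hifigan.py \
    train_dataset=/path/to/train_manifest.json \
    validation_datasets=/path/to/val_manifest.json \
    model/generator=v2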

nemo/collections/asr/parts/features.py

@@ -220,6 +220,7 @@ class FilterbankFeatures(nn.Module):
stft_conv=False,
pad_value=0,
mag_power=2.0,
use_grads=False,
):
super().__init__()
self.log_zero_guard_value = log_zero_guard_value
@@ -301,6 +302,11 @@
f"log_zero_guard_type parameter. It must be either 'add' or "
f"'clamp'."
)
self.use_grads = use_grads
if not use_grads:
self.forward = torch.no_grad()(self.forward)
# log_zero_guard_value is the small value we want to use; we support
# an actual number, or "tiny", or "eps"
self.log_zero_guard_type = log_zero_guard_type
@@ -311,6 +317,7 @@
logging.debug(f"n_mels: {nfilt}")
logging.debug(f"fmin: {lowfreq}")
logging.debug(f"fmax: {highfreq}")
logging.debug(f"using grads: {use_grads}")
def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
@@ -334,7 +341,6 @@
def filter_banks(self):
return self.fb
- @torch.no_grad()
def forward(self, x, seq_len):
seq_len = self.get_seq_len(seq_len.float())
@@ -356,7 +362,9 @@
# torch returns real, imag; so convert to magnitude
if not self.stft_conv:
- x = torch.sqrt(x.pow(2).sum(-1))
+ # guard is needed for sqrt if grads are passed through
+ guard = 0 if not self.use_grads else CONSTANT
+ x = torch.sqrt(x.pow(2).sum(-1) + guard)
# get power spectrum
if self.mag_power != 1.0:
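Two details in this hunk are easy to miss. First, wrapping the bound method with torch.no_grad()(self.forward) replaces the old @torch.no_grad() decorator, so gradient flow becomes opt-in per instance (HifiGAN needs grads through its mel loss). Second, the guard under the square root matters because the derivative of sqrt blows up at zero, so silent frames would produce NaN gradients. A minimal standalone sketch of both ideas, with CONSTANT standing in for the small module-level constant in features.py:

import torch

CONSTANT = 1e-5  # stand-in for the module-level constant in features.py

class Spec(torch.nn.Module):
    def __init__(self, use_grads=False):
        super().__init__()
        self.use_grads = use_grads
        if not use_grads:
            # per-instance opt-out: wrap the bound method instead of decorating the class
            self.forward = torch.no_grad()(self.forward)

    def forward(self, x):
        # the guard keeps sqrt differentiable at zero magnitude
        guard = CONSTANT if self.use_grads else 0
        return torch.sqrt(x.pow(2).sum(-1) + guard)

x = torch.zeros(2, 3, 2, requires_grad=True)
Spec(use_grads=True)(x).sum().backward()
assert torch.isfinite(x.grad).all()  # would be NaN without the guard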

nemo/collections/tts/losses/hifigan_losses.py

@@ -0,0 +1,132 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The forward functions of the following classes are based on code from https://github.com/jik876/hifi-gan:
# FeatureMatchingLoss, DiscriminatorLoss, GeneratorLoss
import torch
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types.elements import LossType, VoidType
from nemo.core.neural_types.neural_type import NeuralType
class FeatureMatchingLoss(Loss):
"""Feature Matching Loss module"""
@property
def input_types(self):
return {
"fmap_r": NeuralType(elements_type=VoidType()),
"fmap_g": NeuralType(elements_type=VoidType()),
}
@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType()),
}
@typecheck()
def forward(self, fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
class DiscriminatorLoss(Loss):
"""Discriminator Loss module"""
@property
def input_types(self):
return {
"disc_real_outputs": NeuralType(('B', 'T'), VoidType()),
"disc_generated_outputs": NeuralType(('B', 'T'), VoidType()),
}
@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType()),
"real_losses": NeuralType(elements_type=LossType()),
"fake_losses": NeuralType(elements_type=LossType()),
}
@typecheck()
def forward(self, disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg ** 2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
class GeneratorLoss(Loss):
"""Generator Loss module"""
@property
def input_types(self):
return {
"disc_outputs": NeuralType(('B', 'T'), VoidType()),
}
@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType()),
"fake_losses": NeuralType(elements_type=LossType()),
}
@typecheck()
def forward(self, disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
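These are the least-squares GAN objectives from the HiFi-GAN paper: the discriminator pushes real scores toward 1 and generated scores toward 0, the generator pushes generated scores toward 1, and feature matching is an L1 distance over discriminator activations, scaled by 2. A plain-PyTorch sketch of the same arithmetic on dummy scores (shapes are illustrative):

import torch

real_scores = [torch.full((4, 10), 0.9)]  # one discriminator, batch of 4
fake_scores = [torch.full((4, 10), 0.1)]

# DiscriminatorLoss: mean((1 - D(x))^2) + mean(D(G(z))^2), summed over discriminators
d_loss = sum(torch.mean((1 - r) ** 2) + torch.mean(f ** 2)
             for r, f in zip(real_scores, fake_scores))

# GeneratorLoss: mean((1 - D(G(z)))^2), summed over discriminators
g_loss = sum(torch.mean((1 - f) ** 2) for f in fake_scores)

# FeatureMatchingLoss: 2 * sum of L1 distances between real and fake feature maps
fmap_r = [[torch.randn(4, 8), torch.randn(4, 8)]]
fmap_g = [[torch.randn(4, 8), torch.randn(4, 8)]]
fm_loss = 2 * sum(torch.mean(torch.abs(r - g))
                  for dr, dg in zip(fmap_r, fmap_g) for r, g in zip(dr, dg))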

nemo/collections/tts/models/__init__.py

@@ -15,6 +15,7 @@
from nemo.collections.tts.models.degli import DegliModel
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
from nemo.collections.tts.models.glow_tts import GlowTTSModel
from nemo.collections.tts.models.hifigan import HifiGanModel
from nemo.collections.tts.models.melgan import MelGanModel
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
@@ -37,4 +38,5 @@ __all__ = [
"TalkNetSpectModel",
"UniGlowModel",
"MelGanModel",
"HifiGanModel",
]

nemo/collections/tts/models/hifigan.py

@@ -0,0 +1,219 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import torch
import torch.nn.functional as F
import wandb
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.loggers.wandb import WandbLogger
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss
from nemo.collections.tts.models.base import Vocoder
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging
class HifiGanModel(Vocoder):
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
if isinstance(cfg, dict):
cfg = OmegaConf.create(cfg)
super().__init__(cfg=cfg, trainer=trainer)
self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
# use a different melspec extractor because:
# 1. we need to pass grads
# 2. we need to remove the fmax limitation
self.trg_melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
self.generator = instantiate(cfg.generator)
self.mpd = MultiPeriodDiscriminator()
self.msd = MultiScaleDiscriminator()
self.feature_loss = FeatureMatchingLoss()
self.discriminator_loss = DiscriminatorLoss()
self.generator_loss = GeneratorLoss()
self.sample_rate = self._cfg.preprocessor.sample_rate
def configure_optimizers(self):
self.optim_g = torch.optim.AdamW(
self.generator.parameters(), self._cfg.optim.lr, betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2]
)
self.optim_d = torch.optim.AdamW(
itertools.chain(self.msd.parameters(), self.mpd.parameters()),
self._cfg.optim.lr,
betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2],
)
self.scheduler_g = torch.optim.lr_scheduler.StepLR(
self.optim_g, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
)
self.scheduler_d = torch.optim.lr_scheduler.StepLR(
self.optim_d, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
)
return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
@property
def input_types(self):
return {
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
}
@property
def output_types(self):
return {
"audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)),
}
@typecheck()
def forward(self, *, spec):
"""
Runs the generator. For inputs and outputs, see input_types and output_types.
"""
return self.generator(x=spec)
@typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
return self(spec=spec).squeeze(1)
def training_step(self, batch, batch_idx, optimizer_idx):
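# manual optimization: the yaml sets automatic_optimization: false, so both
# optimizers are zeroed, backpropagated and stepped explicitly below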
audio, audio_len = batch
# mel as input for generator
audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)
# mel as input for L1 mel loss
audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
audio = audio.unsqueeze(1)
audio_pred = self.generator(x=audio_mel)
audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)
# train discriminator
self.optim_d.zero_grad()
mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
loss_disc_mpd, _, _ = self.discriminator_loss(
disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
)
msd_score_real, msd_score_gen, _, _ = self.msd(y=audio, y_hat=audio_pred.detach())
loss_disc_msd, _, _ = self.discriminator_loss(
disc_real_outputs=msd_score_real, disc_generated_outputs=msd_score_gen
)
loss_d = loss_disc_msd + loss_disc_mpd
self.manual_backward(loss_d, self.optim_d)
self.optim_d.step()
# train generator
self.optim_g.zero_grad()
loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel) * 45
_, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
_, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen)
loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen)
loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel
self.manual_backward(loss_g, self.optim_g)
self.optim_g.step()
metrics = {
"g_l1_loss": loss_mel,
"g_loss_fm_mpd": loss_fm_mpd,
"g_loss_fm_msd": loss_fm_msd,
"g_loss_gen_mpd": loss_gen_mpd,
"g_loss_gen_msd": loss_gen_msd,
"g_loss": loss_g,
"d_loss_mpd": loss_disc_mpd,
"d_loss_msd": loss_disc_msd,
"d_loss": loss_d,
"global_step": self.global_step,
"lr": self.optim_g.param_groups[0]['lr'],
}
self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
def validation_step(self, batch, batch_idx):
audio, audio_len = batch
audio_mel, audio_mel_len = self.audio_to_melspec_precessor(audio, audio_len)
audio_pred = self(spec=audio_mel)
audio_pred_mel, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len)
loss_mel = F.l1_loss(audio_mel, audio_pred_mel)
self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)
# plot audio once per epoch
if batch_idx == 0 and isinstance(self.logger, WandbLogger):
clips = []
specs = []
for i in range(min(5, audio.shape[0])):
clips += [
wandb.Audio(
audio[i, : audio_len[i]].data.cpu().numpy(),
caption=f"real audio {i}",
sample_rate=self.sample_rate,
),
wandb.Audio(
audio_pred[i, 0, : audio_len[i]].data.cpu().numpy().astype('float32'),
caption=f"generated audio {i}",
sample_rate=self.sample_rate,
),
]
specs += [
wandb.Image(
plot_spectrogram_to_numpy(audio_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
caption=f"real audio {i}",
),
wandb.Image(
plot_spectrogram_to_numpy(audio_pred_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
caption=f"generated audio {i}",
),
]
self.logger.experiment.log({"audio": clips, "specs": specs}, commit=False)
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
raise ValueError(f"No dataset for {name}")
if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
raise ValueError(f"No dataloder_params for {name}")
if shuffle_should_be:
if 'shuffle' not in cfg.dataloader_params:
logging.warning(
f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
"config. Manually setting to True"
)
with open_dict(cfg["dataloader_params"]):
cfg.dataloader_params.shuffle = True
elif not cfg.dataloader_params.shuffle:
logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
elif not shuffle_should_be and cfg.dataloader_params.shuffle:
logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")
dataset = instantiate(cfg.dataset)
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def setup_training_data(self, cfg):
self._train_dl = self.__setup_dataloader_from_config(cfg)
def setup_validation_data(self, cfg):
self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")
@classmethod
def list_available_models(cls) -> 'Optional[Dict[str, str]]':
# TODO
pass
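At inference time the model behaves like any other NeMo Vocoder. A sketch, assuming a trained checkpoint at a placeholder path and an 80-band mel input matching the preprocessor config:

import torch
from nemo.collections.tts.models import HifiGanModel

model = HifiGanModel.restore_from("hifigan.nemo").eval()  # placeholder checkpoint path

spec = torch.randn(1, 80, 200)  # (batch, n_mels, frames)
with torch.no_grad():
    audio = model.convert_spectrogram_to_audio(spec=spec)  # (batch, frames * 256) samples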

nemo/collections/tts/modules/hifigan_modules.py

@@ -0,0 +1,434 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
# ResBlock1, ResBlock2, Generator, DiscriminatorP, DiscriminatorS, MultiScaleDiscriminator,
# MultiPeriodDiscriminator, init_weights, get_padding
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from nemo.core.classes import NeuralModule
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
from nemo.core.neural_types.neural_type import NeuralType
LRELU_SLOPE = 0.1
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size, dilation):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size, dilation):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(NeuralModule):
def __init__(
self,
resblock,
upsample_rates,
upsample_kernel_sizes,
upsample_initial_channel,
resblock_kernel_sizes,
resblock_dilation_sizes,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = weight_norm(Conv1d(80, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock == 1 else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2 ** i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
@property
def input_types(self):
return {
"x": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
}
@property
def output_types(self):
return {
"audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@typecheck()
def forward(self, x):
x = self.conv_pre(x)
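# each stage: leaky ReLU -> transposed-conv upsampling, then multi-receptive
# field fusion: average the outputs of num_kernels parallel ResBlocks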
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class DiscriminatorP(NeuralModule):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
@property
def input_types(self):
return {
"x": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"decision": NeuralType(('B', 'T'), VoidType()),
"feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
}
@typecheck()
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
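# e.g. period=3, t=22051: reflect-pad by 2 samples, then view as (b, c, 7351, 3);
# the (kernel_size, 1) convs below scan period-spaced samples along time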
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(NeuralModule):
def __init__(self):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorP(2), DiscriminatorP(3), DiscriminatorP(5), DiscriminatorP(7), DiscriminatorP(11),]
)
@property
def input_types(self):
return {
"y": NeuralType(('B', 'S', 'T'), AudioSignal()),
"y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"real_scores": NeuralType(('B', 'T'), VoidType()),
"fake_scores": NeuralType(('B', 'T'), VoidType()),
"real_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
"fake_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
}
@typecheck()
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(NeuralModule):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
@property
def input_types(self):
return {
"x": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"decision": NeuralType(('B', 'T'), VoidType()),
"feature_maps": NeuralType(("B", "C", "T"), VoidType()),
}
@typecheck()
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(NeuralModule):
def __init__(self):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorS(use_spectral_norm=True), DiscriminatorS(), DiscriminatorS(),]
)
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
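# AvgPool1d(4, stride=2) halves the sample rate before discriminators 1 and 2,
# so the three DiscriminatorS instances see audio at x1, x2 and x4 downsampling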
@property
def input_types(self):
return {
"y": NeuralType(('B', 'S', 'T'), AudioSignal()),
"y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"real_scores": NeuralType(('B', 'T'), VoidType()),
"fake_scores": NeuralType(('B', 'T'), VoidType()),
"real_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
"fake_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
}
@typecheck()
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i - 1](y)
y_hat = self.meanpools[i - 1](y_hat)
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs