HifiGAN MelSpectrogram Vocoder Model (#1706)
* HifiGAN: initial commit Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * switched to manual optimization, added grads to FilterBanks object, added v1,v2,v3 configurations for the generator added multiple audio/spec examples in wandb Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * added trg_mel_fn for different mel params in L1 loss Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * fixed ci checks Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * moved losses to separate classes, fixed exp_manager in yaml, style fixes Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * fixed `min_duration` in yamls of MelGAN and HifiGAN Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * fixed f_max in loss calculation, added docs Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com> Co-authored-by: Jason <jasoli@nvidia.com>
This commit is contained in:
parent
562087498d
commit
7d8b7d679a
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -11,6 +11,7 @@ bert.pt.json
|
||||||
work
|
work
|
||||||
runs
|
runs
|
||||||
fastspeech_output
|
fastspeech_output
|
||||||
|
.hydra
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
|
@ -51,6 +51,12 @@ Full-band MelGAN
|
||||||
The :class:`MelGanModel<nemo.collections.tts.models.MelGanModel>` class implements Full-band MelGAN as described in
|
The :class:`MelGanModel<nemo.collections.tts.models.MelGanModel>` class implements Full-band MelGAN as described in
|
||||||
:cite:`yang2020multiband`. MelGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
|
:cite:`yang2020multiband`. MelGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
|
||||||
|
|
||||||
|
HifiGAN
|
||||||
|
-------
|
||||||
|
|
||||||
|
The :class:`HifiGanModel<nemo.collections.tts.models.HifiGanModel>` class implements HifiGAN as described in
|
||||||
|
:cite:`kong2020hifi`. HifiGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
|
||||||
|
|
||||||
References
|
References
|
||||||
----------
|
----------
|
||||||
|
|
||||||
|
|
|
@ -37,4 +37,11 @@
|
||||||
eprint={2005.05106},
|
eprint={2005.05106},
|
||||||
archivePrefix={arXiv},
|
archivePrefix={arXiv},
|
||||||
primaryClass={cs.SD}
|
primaryClass={cs.SD}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@article{kong2020hifi,
|
||||||
|
title={HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis},
|
||||||
|
author={Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung},
|
||||||
|
journal={arXiv preprint arXiv:2010.05646},
|
||||||
|
year={2020}
|
||||||
|
}
|
||||||
|
|
84
examples/tts/conf/hifigan/hifigan.yaml
Normal file
84
examples/tts/conf/hifigan/hifigan.yaml
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
name: "HifiGan"
|
||||||
|
train_dataset: ???
|
||||||
|
validation_datasets: ???
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- model/generator: v1
|
||||||
|
|
||||||
|
model:
|
||||||
|
train_ds:
|
||||||
|
dataset:
|
||||||
|
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
||||||
|
manifest_filepath: ${train_dataset}
|
||||||
|
max_duration: null
|
||||||
|
min_duration: 0.75
|
||||||
|
n_segments: 8192
|
||||||
|
trim: false
|
||||||
|
dataloader_params:
|
||||||
|
drop_last: false
|
||||||
|
shuffle: true
|
||||||
|
batch_size: 16
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
validation_ds:
|
||||||
|
dataset:
|
||||||
|
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
||||||
|
manifest_filepath: ${validation_datasets}
|
||||||
|
max_duration: null
|
||||||
|
min_duration: null
|
||||||
|
n_segments: -1
|
||||||
|
trim: false
|
||||||
|
dataloader_params:
|
||||||
|
drop_last: false
|
||||||
|
shuffle: false
|
||||||
|
batch_size: 3
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
preprocessor:
|
||||||
|
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||||
|
dither: 0.0
|
||||||
|
frame_splicing: 1
|
||||||
|
nfilt: 80
|
||||||
|
highfreq: 8000
|
||||||
|
log: true
|
||||||
|
log_zero_guard_type: clamp
|
||||||
|
log_zero_guard_value: 1e-05
|
||||||
|
lowfreq: 0
|
||||||
|
mag_power: 1.0
|
||||||
|
n_fft: 1024
|
||||||
|
n_window_size: 1024
|
||||||
|
n_window_stride: 256
|
||||||
|
normalize: null
|
||||||
|
pad_to: 0
|
||||||
|
pad_value: -11.52
|
||||||
|
preemph: null
|
||||||
|
sample_rate: 22050
|
||||||
|
stft_exact_pad: true
|
||||||
|
stft_conv: false
|
||||||
|
window: hann
|
||||||
|
use_grads: false
|
||||||
|
|
||||||
|
optim:
|
||||||
|
lr: 0.0002
|
||||||
|
adam_b1: 0.8
|
||||||
|
adam_b2: 0.99
|
||||||
|
lr_decay: 0.999
|
||||||
|
lr_step: 15
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
gpus: -1 # number of gpus
|
||||||
|
max_epochs: 1000000
|
||||||
|
num_nodes: 1
|
||||||
|
accelerator: ddp
|
||||||
|
accumulate_grad_batches: 1
|
||||||
|
checkpoint_callback: False # Provided by exp_manager
|
||||||
|
logger: False # Provided by exp_manager
|
||||||
|
log_every_n_steps: 100
|
||||||
|
check_val_every_n_epoch: 10
|
||||||
|
automatic_optimization: false
|
||||||
|
|
||||||
|
exp_manager:
|
||||||
|
exp_dir: null
|
||||||
|
name: ${name}
|
||||||
|
create_tensorboard_logger: True
|
||||||
|
create_checkpoint_callback: True
|
8
examples/tts/conf/hifigan/model/generator/v1.yaml
Normal file
8
examples/tts/conf/hifigan/model/generator/v1.yaml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# @package _group_
|
||||||
|
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
|
||||||
|
resblock: 1
|
||||||
|
upsample_rates: [8,8,2,2]
|
||||||
|
upsample_kernel_sizes: [16,16,4,4]
|
||||||
|
upsample_initial_channel: 512
|
||||||
|
resblock_kernel_sizes: [3,7,11]
|
||||||
|
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
8
examples/tts/conf/hifigan/model/generator/v2.yaml
Normal file
8
examples/tts/conf/hifigan/model/generator/v2.yaml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# @package _group_
|
||||||
|
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
|
||||||
|
resblock: 1
|
||||||
|
upsample_rates: [8,8,2,2]
|
||||||
|
upsample_kernel_sizes: [16,16,4,4]
|
||||||
|
upsample_initial_channel: 128
|
||||||
|
resblock_kernel_sizes: [3,7,11]
|
||||||
|
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
8
examples/tts/conf/hifigan/model/generator/v3.yaml
Normal file
8
examples/tts/conf/hifigan/model/generator/v3.yaml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# @package _group_
|
||||||
|
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
|
||||||
|
resblock: 2
|
||||||
|
upsample_rates: [8,8,4]
|
||||||
|
upsample_kernel_sizes: [16,16,8]
|
||||||
|
upsample_initial_channel: 256
|
||||||
|
resblock_kernel_sizes: [3,5,7]
|
||||||
|
resblock_dilation_sizes: [[1,2], [2,6], [3,12]]
|
|
@ -28,7 +28,7 @@ model:
|
||||||
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
||||||
manifest_filepath: ${validation_datasets}
|
manifest_filepath: ${validation_datasets}
|
||||||
max_duration: null
|
max_duration: null
|
||||||
min_duration: 0.75
|
min_duration: null
|
||||||
n_segments: -1
|
n_segments: -1
|
||||||
trim: false
|
trim: false
|
||||||
dataloader_params:
|
dataloader_params:
|
||||||
|
|
34
examples/tts/hifigan.py
Normal file
34
examples/tts/hifigan.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
from nemo.collections.common.callbacks import LogEpochTimeCallback
|
||||||
|
from nemo.collections.tts.models import HifiGanModel
|
||||||
|
from nemo.core.config import hydra_runner
|
||||||
|
from nemo.utils.exp_manager import exp_manager
|
||||||
|
|
||||||
|
|
||||||
|
@hydra_runner(config_path="conf/hifigan", config_name="hifigan")
def main(cfg):
    """Train a HifiGAN vocoder from the Hydra configuration.

    Args:
        cfg: Hydra/OmegaConf config; ``cfg.trainer`` is expanded into the
            Lightning ``Trainer`` constructor, ``cfg.model`` configures the
            model, and the optional ``cfg.exp_manager`` section is handed to
            ``exp_manager``.
    """
    lightning_trainer = pl.Trainer(**cfg.trainer)
    # A missing ``exp_manager`` section is passed through as None.
    exp_manager(lightning_trainer, cfg.get("exp_manager", None))
    vocoder = HifiGanModel(cfg=cfg.model, trainer=lightning_trainer)
    # Register the epoch-time callback after the model is built, before fitting.
    lightning_trainer.callbacks.extend([LogEpochTimeCallback()])
    lightning_trainer.fit(vocoder)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
|
|
@ -220,6 +220,7 @@ class FilterbankFeatures(nn.Module):
|
||||||
stft_conv=False,
|
stft_conv=False,
|
||||||
pad_value=0,
|
pad_value=0,
|
||||||
mag_power=2.0,
|
mag_power=2.0,
|
||||||
|
use_grads=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.log_zero_guard_value = log_zero_guard_value
|
self.log_zero_guard_value = log_zero_guard_value
|
||||||
|
@ -301,6 +302,11 @@ class FilterbankFeatures(nn.Module):
|
||||||
f"log_zero_guard_type parameter. It must be either 'add' or "
|
f"log_zero_guard_type parameter. It must be either 'add' or "
|
||||||
f"'clamp'."
|
f"'clamp'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.use_grads = use_grads
|
||||||
|
if not use_grads:
|
||||||
|
self.forward = torch.no_grad()(self.forward)
|
||||||
|
|
||||||
# log_zero_guard_value is the small value we want to use, we support
|
# log_zero_guard_value is the small value we want to use, we support
|
||||||
# an actual number, or "tiny", or "eps"
|
# an actual number, or "tiny", or "eps"
|
||||||
self.log_zero_guard_type = log_zero_guard_type
|
self.log_zero_guard_type = log_zero_guard_type
|
||||||
|
@ -311,6 +317,7 @@ class FilterbankFeatures(nn.Module):
|
||||||
logging.debug(f"n_mels: {nfilt}")
|
logging.debug(f"n_mels: {nfilt}")
|
||||||
logging.debug(f"fmin: {lowfreq}")
|
logging.debug(f"fmin: {lowfreq}")
|
||||||
logging.debug(f"fmax: {highfreq}")
|
logging.debug(f"fmax: {highfreq}")
|
||||||
|
logging.debug(f"using grads: {use_grads}")
|
||||||
|
|
||||||
def log_zero_guard_value_fn(self, x):
|
def log_zero_guard_value_fn(self, x):
|
||||||
if isinstance(self.log_zero_guard_value, str):
|
if isinstance(self.log_zero_guard_value, str):
|
||||||
|
@ -334,7 +341,6 @@ class FilterbankFeatures(nn.Module):
|
||||||
def filter_banks(self):
|
def filter_banks(self):
|
||||||
return self.fb
|
return self.fb
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(self, x, seq_len):
|
def forward(self, x, seq_len):
|
||||||
seq_len = self.get_seq_len(seq_len.float())
|
seq_len = self.get_seq_len(seq_len.float())
|
||||||
|
|
||||||
|
@ -356,7 +362,9 @@ class FilterbankFeatures(nn.Module):
|
||||||
|
|
||||||
# torch returns real, imag; so convert to magnitude
|
# torch returns real, imag; so convert to magnitude
|
||||||
if not self.stft_conv:
|
if not self.stft_conv:
|
||||||
x = torch.sqrt(x.pow(2).sum(-1))
|
# guard is needed for sqrt if grads are passed through
|
||||||
|
guard = 0 if not self.use_grads else CONSTANT
|
||||||
|
x = torch.sqrt(x.pow(2).sum(-1) + guard)
|
||||||
|
|
||||||
# get power spectrum
|
# get power spectrum
|
||||||
if self.mag_power != 1.0:
|
if self.mag_power != 1.0:
|
||||||
|
|
132
nemo/collections/tts/losses/hifigan_losses.py
Normal file
132
nemo/collections/tts/losses/hifigan_losses.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2020 Jungil Kong
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
# The forward functions of the following classes are based on code from https://github.com/jik876/hifi-gan:
|
||||||
|
# FeatureMatchingLoss, DiscriminatorLoss, GeneratorLoss
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from nemo.core.classes import Loss, typecheck
|
||||||
|
from nemo.core.neural_types.elements import LossType, VoidType
|
||||||
|
from nemo.core.neural_types.neural_type import NeuralType
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMatchingLoss(Loss):
    """Feature Matching Loss module.

    Accumulates the mean absolute difference between paired real and
    generated feature maps and returns twice that sum.
    """

    @property
    def input_types(self):
        """Neural types of the two (untyped) feature-map inputs."""
        types = dict(
            fmap_r=NeuralType(elements_type=VoidType()),
            fmap_g=NeuralType(elements_type=VoidType()),
        )
        return types

    @property
    def output_types(self):
        """Neural type of the single loss output."""
        types = dict(loss=NeuralType(elements_type=LossType()))
        return types

    @typecheck()
    def forward(self, fmap_r, fmap_g):
        """Compute the feature matching loss.

        Args:
            fmap_r: iterable of iterables of feature maps from real audio.
            fmap_g: iterable of iterables of feature maps from generated
                audio, paired element-wise with ``fmap_r``.

        Returns:
            Twice the sum of the per-feature-map mean absolute differences
            (the integer ``0`` when there are no feature maps).
        """
        total = 0
        for real_maps, gen_maps in zip(fmap_r, fmap_g):
            for real_feat, gen_feat in zip(real_maps, gen_maps):
                total += torch.mean(torch.abs(real_feat - gen_feat))

        return total * 2
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorLoss(Loss):
    """Discriminator Loss module.

    Least-squares style objective: real outputs are pushed towards 1 and
    generated outputs towards 0.
    """

    @property
    def input_types(self):
        """Neural types of the real and generated discriminator outputs."""
        types = dict(
            disc_real_outputs=NeuralType(('B', 'T'), VoidType()),
            disc_generated_outputs=NeuralType(('B', 'T'), VoidType()),
        )
        return types

    @property
    def output_types(self):
        """Neural types of the total loss and the per-discriminator losses."""
        types = dict(
            loss=NeuralType(elements_type=LossType()),
            real_losses=NeuralType(elements_type=LossType()),
            fake_losses=NeuralType(elements_type=LossType()),
        )
        return types

    @typecheck()
    def forward(self, disc_real_outputs, disc_generated_outputs):
        """Compute the discriminator loss.

        Args:
            disc_real_outputs: iterable of discriminator outputs on real audio.
            disc_generated_outputs: iterable of discriminator outputs on
                generated audio, paired element-wise with the real outputs.

        Returns:
            A tuple ``(loss, real_losses, fake_losses)``: the summed loss,
            followed by the per-discriminator real and generated terms as
            Python floats (obtained with ``.item()``).
        """
        total = 0
        real_terms = []
        fake_terms = []
        for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
            real_term = torch.mean((1 - real_out) ** 2)
            fake_term = torch.mean(fake_out ** 2)
            total += real_term + fake_term
            # Only detached scalar values are kept for reporting.
            real_terms.append(real_term.item())
            fake_terms.append(fake_term.item())

        return total, real_terms, fake_terms
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorLoss(Loss):
    """Generator Loss module.

    Least-squares style objective: discriminator outputs on generated audio
    are pushed towards 1.
    """

    @property
    def input_types(self):
        """Neural type of the discriminator outputs on generated audio."""
        types = dict(disc_outputs=NeuralType(('B', 'T'), VoidType()))
        return types

    @property
    def output_types(self):
        """Neural types of the total loss and the per-discriminator losses."""
        types = dict(
            loss=NeuralType(elements_type=LossType()),
            fake_losses=NeuralType(elements_type=LossType()),
        )
        return types

    @typecheck()
    def forward(self, disc_outputs):
        """Compute the adversarial generator loss.

        Args:
            disc_outputs: iterable of discriminator outputs on generated audio.

        Returns:
            A tuple ``(loss, fake_losses)``: the summed loss and the list of
            per-discriminator loss tensors.
        """
        total = 0
        per_disc_losses = []
        for fake_out in disc_outputs:
            fake_loss = torch.mean((1 - fake_out) ** 2)
            per_disc_losses.append(fake_loss)
            total += fake_loss

        return total, per_disc_losses
|
|
@ -15,6 +15,7 @@
|
||||||
from nemo.collections.tts.models.degli import DegliModel
|
from nemo.collections.tts.models.degli import DegliModel
|
||||||
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
|
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
|
||||||
from nemo.collections.tts.models.glow_tts import GlowTTSModel
|
from nemo.collections.tts.models.glow_tts import GlowTTSModel
|
||||||
|
from nemo.collections.tts.models.hifigan import HifiGanModel
|
||||||
from nemo.collections.tts.models.melgan import MelGanModel
|
from nemo.collections.tts.models.melgan import MelGanModel
|
||||||
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
|
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
|
||||||
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
|
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
|
||||||
|
@ -37,4 +38,5 @@ __all__ = [
|
||||||
"TalkNetSpectModel",
|
"TalkNetSpectModel",
|
||||||
"UniGlowModel",
|
"UniGlowModel",
|
||||||
"MelGanModel",
|
"MelGanModel",
|
||||||
|
"HifiGanModel",
|
||||||
]
|
]
|
||||||
|
|
219
nemo/collections/tts/models/hifigan.py
Normal file
219
nemo/collections/tts/models/hifigan.py
Normal file
|
@ -0,0 +1,219 @@
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import wandb
|
||||||
|
from hydra.utils import instantiate
|
||||||
|
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||||
|
from pytorch_lightning.loggers.wandb import WandbLogger
|
||||||
|
|
||||||
|
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
|
||||||
|
from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss
|
||||||
|
from nemo.collections.tts.models.base import Vocoder
|
||||||
|
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
|
||||||
|
from nemo.core.classes.common import typecheck
|
||||||
|
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
|
||||||
|
from nemo.core.neural_types.neural_type import NeuralType
|
||||||
|
from nemo.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
class HifiGanModel(Vocoder):
    """HifiGAN vocoder model: converts mel-spectrograms to audio.

    Holds a configurable generator, a multi-period and a multi-scale
    discriminator, and the three GAN loss modules. Training uses manual
    optimization: each ``training_step`` updates the discriminators first
    and then the generator.
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Build the preprocessors, generator, discriminators and losses.

        Args:
            cfg: model config (a plain dict is converted to a DictConfig);
                must provide ``preprocessor``, ``generator`` and ``optim``.
            trainer: optional Lightning trainer passed to the base class.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        # NOTE(review): "precessor" looks like a typo for "processor"; the
        # attribute name is left unchanged so existing references keep working.
        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # use a different melspec extractor because:
        # 1. we need to pass grads
        # 2. we need to remove the fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
        self.generator = instantiate(cfg.generator)
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        self.sample_rate = self._cfg.preprocessor.sample_rate

    def configure_optimizers(self):
        """Create AdamW optimizers and StepLR schedulers for G and D.

        Both optimizers share the learning rate and betas from
        ``cfg.optim``; the discriminator optimizer covers the parameters of
        both the multi-scale and multi-period discriminators.

        Returns:
            ``([optim_g, optim_d], [scheduler_g, scheduler_d])``.
        """
        self.optim_g = torch.optim.AdamW(
            self.generator.parameters(), self._cfg.optim.lr, betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2]
        )
        self.optim_d = torch.optim.AdamW(
            itertools.chain(self.msd.parameters(), self.mpd.parameters()),
            self._cfg.optim.lr,
            betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2],
        )

        self.scheduler_g = torch.optim.lr_scheduler.StepLR(
            self.optim_g, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
        )
        self.scheduler_d = torch.optim.lr_scheduler.StepLR(
            self.optim_d, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
        )

        return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]

    @property
    def input_types(self):
        """Neural type of the ``forward`` input: a (B, D, T) mel-spectrogram."""
        return {
            "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        """Neural type of the ``forward`` output: (B, S, T) audio."""
        return {
            "audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)),
        }

    @typecheck()
    def forward(self, *, spec):
        """
        Runs the generator, for inputs and outputs see input_types, and output_types
        """
        return self.generator(x=spec)

    @typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
    def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
        """Run the generator on ``spec`` and drop dimension 1 of its output."""
        return self(spec=spec).squeeze(1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one manual-optimization step for the discriminators, then the generator.

        Args:
            batch: ``(audio, audio_len)`` pair from the dataloader.
            batch_idx: index of the batch (unused).
            optimizer_idx: optimizer index supplied by Lightning (unused;
                both optimizers are stepped manually here).
        """
        audio, audio_len = batch
        # mel as input for generator
        audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)
        # mel as input for L1 mel loss
        audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
        audio = audio.unsqueeze(1)

        audio_pred = self.generator(x=audio_mel)
        audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)

        # train discriminator
        # The prediction is detached so this update does not reach the generator.
        self.optim_d.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
        loss_disc_mpd, _, _ = self.discriminator_loss(
            disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
        )
        msd_score_real, msd_score_gen, _, _ = self.msd(y=audio, y_hat=audio_pred.detach())
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real, disc_generated_outputs=msd_score_gen
        )
        loss_d = loss_disc_msd + loss_disc_mpd
        self.manual_backward(loss_d, self.optim_d)
        self.optim_d.step()

        # train generator
        self.optim_g.zero_grad()
        # The L1 mel loss is weighted by a fixed factor of 45.
        loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel) * 45
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
        loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen)
        loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen)
        loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
        loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel
        self.manual_backward(loss_g, self.optim_g)
        self.optim_g.step()

        metrics = {
            "g_l1_loss": loss_mel,
            "g_loss_fm_mpd": loss_fm_mpd,
            "g_loss_fm_msd": loss_fm_msd,
            "g_loss_gen_mpd": loss_gen_mpd,
            "g_loss_gen_msd": loss_gen_msd,
            "g_loss": loss_g,
            "d_loss_mpd": loss_disc_mpd,
            "d_loss_msd": loss_disc_msd,
            "d_loss": loss_d,
            "global_step": self.global_step,
            "lr": self.optim_g.param_groups[0]['lr'],
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Log the L1 mel loss and, for the first batch, samples to W&B.

        Args:
            batch: ``(audio, audio_len)`` pair from the dataloader.
            batch_idx: index of the batch; audio and spectrogram examples are
                logged only when it is 0 and the logger is a ``WandbLogger``.
        """
        audio, audio_len = batch
        audio_mel, audio_mel_len = self.audio_to_melspec_precessor(audio, audio_len)
        audio_pred = self(spec=audio_mel)

        audio_pred_mel, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len)
        loss_mel = F.l1_loss(audio_mel, audio_pred_mel)

        self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)

        # plot audio once per epoch
        if batch_idx == 0 and isinstance(self.logger, WandbLogger):
            clips = []
            specs = []
            # Log at most five examples from the batch.
            for i in range(min(5, audio.shape[0])):
                clips += [
                    wandb.Audio(
                        audio[i, : audio_len[i]].data.cpu().numpy(),
                        caption=f"real audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                    wandb.Audio(
                        audio_pred[i, 0, : audio_len[i]].data.cpu().numpy().astype('float32'),
                        caption=f"generated audio {i}",
                        sample_rate=self.sample_rate,
                    ),
                ]
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"real audio {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(audio_pred_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
                        caption=f"generated audio {i}",
                    ),
                ]

            self.logger.experiment.log({"audio": clips, "specs": specs}, commit=False)

    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        """Validate a dataloader config section and build the DataLoader.

        Args:
            cfg: config with ``dataset`` and ``dataloader_params`` sections.
            shuffle_should_be: expected value of ``dataloader_params.shuffle``;
                a mismatch is logged, and a missing key is set to True when
                shuffling is expected.
            name: label used in log and error messages.

        Returns:
            A ``torch.utils.data.DataLoader`` over the instantiated dataset,
            using the dataset's own ``collate_fn``.

        Raises:
            ValueError: if ``dataset`` or ``dataloader_params`` is missing or
                is not a ``DictConfig``.
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            # Message names the actual config key so users can find it.
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        """Build the training dataloader from ``cfg``."""
        self._train_dl = self.__setup_dataloader_from_config(cfg)

    def setup_validation_data(self, cfg):
        """Build the validation dataloader from ``cfg`` (shuffle expected off)."""
        self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")

    @classmethod
    def list_available_models(cls) -> 'Optional[Dict[str, str]]':
        """Return the available pretrained models; none are registered yet."""
        # TODO
        return None
|
434
nemo/collections/tts/modules/hifigan_modules.py
Normal file
434
nemo/collections/tts/modules/hifigan_modules.py
Normal file
|
@ -0,0 +1,434 @@
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2020 Jungil Kong
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
|
||||||
|
# ResBlock1, ResBlock2, Generator, DiscriminatorP, DiscriminatorS, MultiScaleDiscriminator,
|
||||||
|
# MultiPeriodDiscriminator, init_weights, get_padding
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||||
|
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||||
|
|
||||||
|
from nemo.core.classes import NeuralModule
|
||||||
|
from nemo.core.classes.common import typecheck
|
||||||
|
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
|
||||||
|
from nemo.core.neural_types.neural_type import NeuralType
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw the weights of a convolution module from a normal distribution.

    Written to be handed to ``nn.Module.apply``: any module whose class name
    does not contain ``"Conv"`` is left untouched.

    Args:
        m: Module to (possibly) re-initialise in place.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    # Matching on the class name covers Conv1d, Conv2d, ConvTranspose1d, ...
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
    """Return the per-side padding used for a dilated convolution.

    Args:
        kernel_size: Size of the convolution kernel.
        dilation: Dilation factor of the convolution.

    Returns:
        ``(kernel_size * dilation - dilation) / 2`` truncated to an ``int``.
    """
    span = kernel_size * dilation - dilation
    return int(span / 2)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, undilated conv) pairs.

    Each pair is preceded by leaky-ReLU activations and its output is added
    back onto its input.

    Args:
        channels: Number of input and output channels of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Indexable holding the dilation of the first convolution in
            each of the three pairs (entries 0, 1 and 2 are read).
    """

    def __init__(self, channels, kernel_size, dilation):
        super().__init__()

        def build_conv(rate):
            # Stride-1, weight-normalised conv padded according to get_padding.
            conv = Conv1d(
                channels, channels, kernel_size, 1, dilation=rate, padding=get_padding(kernel_size, rate),
            )
            return weight_norm(conv)

        # First convolution of each pair: caller-supplied dilation rates.
        self.convs1 = nn.ModuleList([build_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        # Second convolution of each pair: always dilation 1.
        self.convs2 = nn.ModuleList([build_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the three residual pairs in sequence and return the result."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain_conv(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every convolution."""
        # Inside the method body the bare name resolves to the module-level
        # ``torch.nn.utils.remove_weight_norm`` import, not to this method.
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock2(torch.nn.Module):
    """Residual block made of two dilated convolutions with skip connections.

    Args:
        channels: Number of input and output channels of every convolution.
        kernel_size: Kernel size shared by both convolutions.
        dilation: Indexable holding the dilation of each convolution
            (entries 0 and 1 are read).
    """

    def __init__(self, channels, kernel_size, dilation):
        super().__init__()

        def build_conv(rate):
            # Stride-1, weight-normalised conv padded according to get_padding.
            conv = Conv1d(
                channels, channels, kernel_size, 1, dilation=rate, padding=get_padding(kernel_size, rate),
            )
            return weight_norm(conv)

        self.convs = nn.ModuleList([build_conv(dilation[idx]) for idx in range(2)])
        self.convs.apply(init_weights)

    def forward(self, x):
        """Apply each (leaky-ReLU, conv) stage and add its output to its input."""
        for conv in self.convs:
            branch = conv(F.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from both convolutions."""
        # The bare name below is the module-level torch.nn.utils import.
        for conv in self.convs:
            remove_weight_norm(conv)
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(NeuralModule):
    """HiFi-GAN generator: turns a mel-spectrogram into an audio signal.

    The input passes through a pre-convolution, then a stack of transposed
    convolutions. Each transposed convolution halves the channel count and is
    followed by ``len(resblock_kernel_sizes)`` residual blocks whose outputs
    are averaged. A final convolution and ``tanh`` produce the one-channel
    output.

    Args:
        resblock: ``1`` selects ``ResBlock1``; any other value selects ``ResBlock2``.
        upsample_rates: Stride of each transposed convolution.
        upsample_kernel_sizes: Kernel size of each transposed convolution.
        upsample_initial_channel: Channel count produced by the pre-convolution.
        resblock_kernel_sizes: Kernel size of each residual block in a stage.
        resblock_dilation_sizes: Dilations handed to each residual block in a stage.
        initial_input_size: Number of input (mel) channels. Defaults to 80,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        resblock,
        upsample_rates,
        upsample_kernel_sizes,
        upsample_initial_channel,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        initial_input_size=80,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(Conv1d(initial_input_size, upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if resblock == 1 else ResBlock2

        # Upsampling stack: stage i maps C // 2**i channels to C // 2**(i + 1).
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Residual blocks are stored flat: stage i owns the slice
        # [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        # Pre-seed ``ch`` so conv_post can still be built when there are no
        # upsampling stages (the loop below would otherwise leave it unbound).
        ch = upsample_initial_channel
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    @property
    def input_types(self):
        return {
            "x": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @typecheck()
    def forward(self, x):
        """Run the generator on ``x`` and return the tanh-squashed output."""
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # NOTE: this activation uses F.leaky_relu's default slope, not LRELU_SLOPE.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every layer."""
        print('Removing weight norm...')
        # Bare ``remove_weight_norm`` is the module-level torch.nn.utils import.
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorP(NeuralModule):
    """Period sub-discriminator operating on audio folded into 2-D.

    The time axis of the input is reshaped to ``(T // period, period)`` and a
    stack of 2-D convolutions with ``(kernel_size, 1)`` kernels is applied.

    Args:
        period: Width of the folded representation.
        kernel_size: Kernel height of the convolution stack.
        stride: Stride along the folded time axis for the first four convolutions.
        use_spectral_norm: Use spectral norm instead of weight norm when truthy.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # Derive the padding from the actual kernel size; it was previously
        # fixed to the value for kernel_size == 5 regardless of the argument.
        pad = (get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList(
            [
                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=pad)),
                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=pad)),
                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=pad)),
                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=pad)),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    @property
    def input_types(self):
        return {
            "x": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "decision": NeuralType(('B', 'T'), VoidType()),
            "feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
        }

    @typecheck()
    def forward(self, x):
        """Score ``x`` and collect the intermediate activations.

        Returns:
            A tuple ``(decision, feature_maps)``: the flattened output of the
            final convolution, and a list with the activation after every
            convolution (including the final one).
        """
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            # Right-pad the last axis so its length is a multiple of the period.
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPeriodDiscriminator(NeuralModule):
    """Bank of ``DiscriminatorP`` instances with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(2), DiscriminatorP(3), DiscriminatorP(5), DiscriminatorP(7), DiscriminatorP(11),]
        )

    # This property was previously also named ``output_types`` and therefore
    # silently overwritten by the definition below; its keys match the
    # arguments of ``forward``, so it describes the inputs.
    @property
    def input_types(self):
        return {
            "y": NeuralType(('B', 'S', 'T'), AudioSignal()),
            "y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "real_scores": NeuralType(('B', 'T'), VoidType()),
            "fake_scores": NeuralType(('B', 'T'), VoidType()),
            "real_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
            "fake_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
        }

    @typecheck()
    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Returns:
            Four lists with one entry per sub-discriminator: decisions for
            ``y``, decisions for ``y_hat``, feature maps for ``y`` and feature
            maps for ``y_hat``.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorS(NeuralModule):
    """Scale sub-discriminator: a stack of 1-D convolutions over the input signal.

    Args:
        use_spectral_norm: When equal to ``False`` the convolutions are wrapped
            in weight norm, otherwise in spectral norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, kernel_size, stride, keyword arguments)
        layer_table = (
            (1, 128, 15, 1, {"padding": 7}),
            (128, 128, 41, 2, {"groups": 4, "padding": 20}),
            (128, 256, 41, 2, {"groups": 16, "padding": 20}),
            (256, 512, 41, 4, {"groups": 16, "padding": 20}),
            (512, 1024, 41, 4, {"groups": 16, "padding": 20}),
            (1024, 1024, 41, 1, {"groups": 16, "padding": 20}),
            (1024, 1024, 5, 1, {"padding": 2}),
        )
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(n_in, n_out, width, step, **extra))
                for n_in, n_out, width, step, extra in layer_table
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    @property
    def input_types(self):
        return {
            "x": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "decision": NeuralType(('B', 'T'), VoidType()),
            "feature_maps": NeuralType(("B", "C", "T"), VoidType()),
        }

    @typecheck()
    def forward(self, x):
        """Score ``x``; return the flattened decision and all intermediate activations."""
        activations = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            activations.append(x)
        # The output of the final convolution is also recorded as a feature map.
        x = self.conv_post(x)
        activations.append(x)
        decision = torch.flatten(x, 1, -1)

        return decision, activations
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleDiscriminator(NeuralModule):
    """Three ``DiscriminatorS`` instances applied to progressively average-pooled audio.

    The first sub-discriminator uses spectral norm and sees the raw input;
    each later one sees the previous input passed through an ``AvgPool1d``.
    """

    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=True), DiscriminatorS(), DiscriminatorS(),]
        )
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])

    # This property was previously also named ``output_types`` and therefore
    # silently overwritten by the definition below; its keys match the
    # arguments of ``forward``, so it describes the inputs.
    @property
    def input_types(self):
        return {
            "y": NeuralType(('B', 'S', 'T'), AudioSignal()),
            "y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "real_scores": NeuralType(('B', 'T'), VoidType()),
            "fake_scores": NeuralType(('B', 'T'), VoidType()),
            "real_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
            "fake_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
        }

    @typecheck()
    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Returns:
            Four lists with one entry per sub-discriminator: decisions for
            ``y``, decisions for ``y_hat``, feature maps for ``y`` and feature
            maps for ``y_hat``.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                # Pooling is cumulative: each stage pools the previous stage's input.
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
Loading…
Reference in a new issue