HifiGAN MelSpectrogram Vocoder Model (#1706)
* HifiGAN: initial commit Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * switched to manual optimization, added grads to FilterBanks object, added v1,v2,v3 configurations for the generator added multiple audio/spec examples in wandb Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * added trg_mel_fn for different mel params in L1 loss Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * fixed ci checks Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * moved losses to separate classes, fixed exp_manager in yaml, style fixes Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * fixed `min_duration` in yamls of MelGAN and HifiGAN Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> * fixed f_max in loss calculation, added docs Signed-off-by: Felix Kreuk <felixkreuk@gmail.com> Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com> Co-authored-by: Jason <jasoli@nvidia.com>
This commit is contained in:
parent
562087498d
commit
7d8b7d679a
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -11,6 +11,7 @@ bert.pt.json
|
|||
work
|
||||
runs
|
||||
fastspeech_output
|
||||
.hydra
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
|
|
@ -51,6 +51,12 @@ Full-band MelGAN
|
|||
The :class:`MelGanModel<nemo.collections.tts.models.MelGanModel>` class implements Full-band MelGAN as described in
|
||||
:cite:`yang2020multiband`. MelGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
|
||||
|
||||
HifiGAN
|
||||
-------
|
||||
|
||||
The :class:`HifiGanModel<nemo.collections.tts.models.HifiGanModel>` class implements HifiGAN as described in
|
||||
:cite:`kong2020hifi`. HifiGAN is a GAN-based vocoder that converts mel-spectrograms to audio.
|
||||
|
||||
References
|
||||
----------
|
||||
|
||||
|
|
|
@ -37,4 +37,11 @@
|
|||
eprint={2005.05106},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.SD}
|
||||
}
|
||||
}
|
||||
|
||||
@article{kong2020hifi,
|
||||
title={HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis},
|
||||
author={Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung},
|
||||
journal={arXiv preprint arXiv:2010.05646},
|
||||
year={2020}
|
||||
}
|
||||
|
|
84
examples/tts/conf/hifigan/hifigan.yaml
Normal file
84
examples/tts/conf/hifigan/hifigan.yaml
Normal file
|
@ -0,0 +1,84 @@
|
|||
name: "HifiGan"
|
||||
train_dataset: ???
|
||||
validation_datasets: ???
|
||||
|
||||
defaults:
|
||||
- model/generator: v1
|
||||
|
||||
model:
|
||||
train_ds:
|
||||
dataset:
|
||||
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
||||
manifest_filepath: ${train_dataset}
|
||||
max_duration: null
|
||||
min_duration: 0.75
|
||||
n_segments: 8192
|
||||
trim: false
|
||||
dataloader_params:
|
||||
drop_last: false
|
||||
shuffle: true
|
||||
batch_size: 16
|
||||
num_workers: 4
|
||||
|
||||
validation_ds:
|
||||
dataset:
|
||||
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
||||
manifest_filepath: ${validation_datasets}
|
||||
max_duration: null
|
||||
min_duration: null
|
||||
n_segments: -1
|
||||
trim: false
|
||||
dataloader_params:
|
||||
drop_last: false
|
||||
shuffle: false
|
||||
batch_size: 3
|
||||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
frame_splicing: 1
|
||||
nfilt: 80
|
||||
highfreq: 8000
|
||||
log: true
|
||||
log_zero_guard_type: clamp
|
||||
log_zero_guard_value: 1e-05
|
||||
lowfreq: 0
|
||||
mag_power: 1.0
|
||||
n_fft: 1024
|
||||
n_window_size: 1024
|
||||
n_window_stride: 256
|
||||
normalize: null
|
||||
pad_to: 0
|
||||
pad_value: -11.52
|
||||
preemph: null
|
||||
sample_rate: 22050
|
||||
stft_exact_pad: true
|
||||
stft_conv: false
|
||||
window: hann
|
||||
use_grads: false
|
||||
|
||||
optim:
|
||||
lr: 0.0002
|
||||
adam_b1: 0.8
|
||||
adam_b2: 0.99
|
||||
lr_decay: 0.999
|
||||
lr_step: 15
|
||||
|
||||
trainer:
|
||||
gpus: -1 # number of gpus
|
||||
max_epochs: 1000000
|
||||
num_nodes: 1
|
||||
accelerator: ddp
|
||||
accumulate_grad_batches: 1
|
||||
checkpoint_callback: False # Provided by exp_manager
|
||||
logger: False # Provided by exp_manager
|
||||
log_every_n_steps: 100
|
||||
check_val_every_n_epoch: 10
|
||||
automatic_optimization: false
|
||||
|
||||
exp_manager:
|
||||
exp_dir: null
|
||||
name: ${name}
|
||||
create_tensorboard_logger: True
|
||||
create_checkpoint_callback: True
|
8
examples/tts/conf/hifigan/model/generator/v1.yaml
Normal file
8
examples/tts/conf/hifigan/model/generator/v1.yaml
Normal file
|
@ -0,0 +1,8 @@
|
|||
# @package _group_
|
||||
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
|
||||
resblock: 1
|
||||
upsample_rates: [8,8,2,2]
|
||||
upsample_kernel_sizes: [16,16,4,4]
|
||||
upsample_initial_channel: 512
|
||||
resblock_kernel_sizes: [3,7,11]
|
||||
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
8
examples/tts/conf/hifigan/model/generator/v2.yaml
Normal file
8
examples/tts/conf/hifigan/model/generator/v2.yaml
Normal file
|
@ -0,0 +1,8 @@
|
|||
# @package _group_
|
||||
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
|
||||
resblock: 1
|
||||
upsample_rates: [8,8,2,2]
|
||||
upsample_kernel_sizes: [16,16,4,4]
|
||||
upsample_initial_channel: 128
|
||||
resblock_kernel_sizes: [3,7,11]
|
||||
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
8
examples/tts/conf/hifigan/model/generator/v3.yaml
Normal file
8
examples/tts/conf/hifigan/model/generator/v3.yaml
Normal file
|
@ -0,0 +1,8 @@
|
|||
# @package _group_
|
||||
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
|
||||
resblock: 2
|
||||
upsample_rates: [8,8,4]
|
||||
upsample_kernel_sizes: [16,16,8]
|
||||
upsample_initial_channel: 256
|
||||
resblock_kernel_sizes: [3,5,7]
|
||||
resblock_dilation_sizes: [[1,2], [2,6], [3,12]]
|
|
@ -28,7 +28,7 @@ model:
|
|||
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
|
||||
manifest_filepath: ${validation_datasets}
|
||||
max_duration: null
|
||||
min_duration: 0.75
|
||||
min_duration: null
|
||||
n_segments: -1
|
||||
trim: false
|
||||
dataloader_params:
|
||||
|
|
34
examples/tts/hifigan.py
Normal file
34
examples/tts/hifigan.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from nemo.collections.common.callbacks import LogEpochTimeCallback
|
||||
from nemo.collections.tts.models import HifiGanModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
@hydra_runner(config_path="conf/hifigan", config_name="hifigan")
|
||||
def main(cfg):
|
||||
trainer = pl.Trainer(**cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
model = HifiGanModel(cfg=cfg.model, trainer=trainer)
|
||||
epoch_time_logger = LogEpochTimeCallback()
|
||||
trainer.callbacks.extend([epoch_time_logger])
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main() # noqa pylint: disable=no-value-for-parameter
|
|
@ -220,6 +220,7 @@ class FilterbankFeatures(nn.Module):
|
|||
stft_conv=False,
|
||||
pad_value=0,
|
||||
mag_power=2.0,
|
||||
use_grads=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_zero_guard_value = log_zero_guard_value
|
||||
|
@ -301,6 +302,11 @@ class FilterbankFeatures(nn.Module):
|
|||
f"log_zero_guard_type parameter. It must be either 'add' or "
|
||||
f"'clamp'."
|
||||
)
|
||||
|
||||
self.use_grads = use_grads
|
||||
if not use_grads:
|
||||
self.forward = torch.no_grad()(self.forward)
|
||||
|
||||
# log_zero_guard_value is the the small we want to use, we support
|
||||
# an actual number, or "tiny", or "eps"
|
||||
self.log_zero_guard_type = log_zero_guard_type
|
||||
|
@ -311,6 +317,7 @@ class FilterbankFeatures(nn.Module):
|
|||
logging.debug(f"n_mels: {nfilt}")
|
||||
logging.debug(f"fmin: {lowfreq}")
|
||||
logging.debug(f"fmax: {highfreq}")
|
||||
logging.debug(f"using grads: {use_grads}")
|
||||
|
||||
def log_zero_guard_value_fn(self, x):
|
||||
if isinstance(self.log_zero_guard_value, str):
|
||||
|
@ -334,7 +341,6 @@ class FilterbankFeatures(nn.Module):
|
|||
def filter_banks(self):
|
||||
return self.fb
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, seq_len):
|
||||
seq_len = self.get_seq_len(seq_len.float())
|
||||
|
||||
|
@ -356,7 +362,9 @@ class FilterbankFeatures(nn.Module):
|
|||
|
||||
# torch returns real, imag; so convert to magnitude
|
||||
if not self.stft_conv:
|
||||
x = torch.sqrt(x.pow(2).sum(-1))
|
||||
# guard is needed for sqrt if grads are passed through
|
||||
guard = 0 if not self.use_grads else CONSTANT
|
||||
x = torch.sqrt(x.pow(2).sum(-1) + guard)
|
||||
|
||||
# get power spectrum
|
||||
if self.mag_power != 1.0:
|
||||
|
|
132
nemo/collections/tts/losses/hifigan_losses.py
Normal file
132
nemo/collections/tts/losses/hifigan_losses.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Jungil Kong
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# The forward functions onf the following classes are based on code from https://github.com/jik876/hifi-gan:
|
||||
# FeatureMatchingLoss, DiscriminatorLoss, GeneratorLoss
|
||||
|
||||
import torch
|
||||
|
||||
from nemo.core.classes import Loss, typecheck
|
||||
from nemo.core.neural_types.elements import LossType, VoidType
|
||||
from nemo.core.neural_types.neural_type import NeuralType
|
||||
|
||||
|
||||
class FeatureMatchingLoss(Loss):
|
||||
"""Feature Matching Loss module"""
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"fmap_r": NeuralType(elements_type=VoidType()),
|
||||
"fmap_g": NeuralType(elements_type=VoidType()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"loss": NeuralType(elements_type=LossType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
|
||||
|
||||
class DiscriminatorLoss(Loss):
|
||||
"""Discriminator Loss module"""
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"disc_real_outputs": NeuralType(('B', 'T'), VoidType()),
|
||||
"disc_generated_outputs": NeuralType(('B', 'T'), VoidType()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"loss": NeuralType(elements_type=LossType()),
|
||||
"real_losses": NeuralType(elements_type=LossType()),
|
||||
"fake_losses": NeuralType(elements_type=LossType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg ** 2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
class GeneratorLoss(Loss):
|
||||
"""Generator Loss module"""
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"disc_outputs": NeuralType(('B', 'T'), VoidType()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"loss": NeuralType(elements_type=LossType()),
|
||||
"fake_losses": NeuralType(elements_type=LossType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
|
@ -15,6 +15,7 @@
|
|||
from nemo.collections.tts.models.degli import DegliModel
|
||||
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
|
||||
from nemo.collections.tts.models.glow_tts import GlowTTSModel
|
||||
from nemo.collections.tts.models.hifigan import HifiGanModel
|
||||
from nemo.collections.tts.models.melgan import MelGanModel
|
||||
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
|
||||
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
|
||||
|
@ -37,4 +38,5 @@ __all__ = [
|
|||
"TalkNetSpectModel",
|
||||
"UniGlowModel",
|
||||
"MelGanModel",
|
||||
"HifiGanModel",
|
||||
]
|
||||
|
|
219
nemo/collections/tts/models/hifigan.py
Normal file
219
nemo/collections/tts/models/hifigan.py
Normal file
|
@ -0,0 +1,219 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
from pytorch_lightning.loggers.wandb import WandbLogger
|
||||
|
||||
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
|
||||
from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss
|
||||
from nemo.collections.tts.models.base import Vocoder
|
||||
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
|
||||
from nemo.core.classes.common import typecheck
|
||||
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
|
||||
from nemo.core.neural_types.neural_type import NeuralType
|
||||
from nemo.utils import logging
|
||||
|
||||
|
||||
class HifiGanModel(Vocoder):
|
||||
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
|
||||
if isinstance(cfg, dict):
|
||||
cfg = OmegaConf.create(cfg)
|
||||
super().__init__(cfg=cfg, trainer=trainer)
|
||||
|
||||
self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
|
||||
# use a different melspec extractor because:
|
||||
# 1. we need to pass grads
|
||||
# 2. we need remove fmax limitation
|
||||
self.trg_melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
|
||||
self.generator = instantiate(cfg.generator)
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.msd = MultiScaleDiscriminator()
|
||||
self.feature_loss = FeatureMatchingLoss()
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.generator_loss = GeneratorLoss()
|
||||
|
||||
self.sample_rate = self._cfg.preprocessor.sample_rate
|
||||
|
||||
def configure_optimizers(self):
|
||||
self.optim_g = torch.optim.AdamW(
|
||||
self.generator.parameters(), self._cfg.optim.lr, betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2]
|
||||
)
|
||||
self.optim_d = torch.optim.AdamW(
|
||||
itertools.chain(self.msd.parameters(), self.mpd.parameters()),
|
||||
self._cfg.optim.lr,
|
||||
betas=[self._cfg.optim.adam_b1, self._cfg.optim.adam_b2],
|
||||
)
|
||||
|
||||
self.scheduler_g = torch.optim.lr_scheduler.StepLR(
|
||||
self.optim_g, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
|
||||
)
|
||||
self.scheduler_d = torch.optim.lr_scheduler.StepLR(
|
||||
self.optim_d, step_size=self._cfg.optim.lr_step, gamma=self._cfg.optim.lr_decay,
|
||||
)
|
||||
|
||||
return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"audio": NeuralType(('B', 'S', 'T'), AudioSignal(self.sample_rate)),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, *, spec):
|
||||
"""
|
||||
Runs the generator, for inputs and outputs see input_types, and output_types
|
||||
"""
|
||||
return self.generator(x=spec)
|
||||
|
||||
@typecheck(output_types={"audio": NeuralType(('B', 'T'), AudioSignal())})
|
||||
def convert_spectrogram_to_audio(self, spec: 'torch.tensor') -> 'torch.tensor':
|
||||
return self(spec=spec).squeeze(1)
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
audio, audio_len = batch
|
||||
# mel as input for generator
|
||||
audio_mel, _ = self.audio_to_melspec_precessor(audio, audio_len)
|
||||
# mel as input for L1 mel loss
|
||||
audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
|
||||
audio = audio.unsqueeze(1)
|
||||
|
||||
audio_pred = self.generator(x=audio_mel)
|
||||
audio_pred_mel, _ = self.trg_melspec_fn(audio_pred.squeeze(1), audio_len)
|
||||
|
||||
# train discriminator
|
||||
self.optim_d.zero_grad()
|
||||
mpd_score_real, mpd_score_gen, _, _ = self.mpd(y=audio, y_hat=audio_pred.detach())
|
||||
loss_disc_mpd, _, _ = self.discriminator_loss(
|
||||
disc_real_outputs=mpd_score_real, disc_generated_outputs=mpd_score_gen
|
||||
)
|
||||
msd_score_real, msd_score_gen, _, _ = self.msd(y=audio, y_hat=audio_pred.detach())
|
||||
loss_disc_msd, _, _ = self.discriminator_loss(
|
||||
disc_real_outputs=msd_score_real, disc_generated_outputs=msd_score_gen
|
||||
)
|
||||
loss_d = loss_disc_msd + loss_disc_mpd
|
||||
self.manual_backward(loss_d, self.optim_d)
|
||||
self.optim_d.step()
|
||||
|
||||
# train generator
|
||||
self.optim_g.zero_grad()
|
||||
loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel) * 45
|
||||
_, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
|
||||
_, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
|
||||
loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen)
|
||||
loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen)
|
||||
loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
|
||||
loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
|
||||
loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel
|
||||
self.manual_backward(loss_g, self.optim_g)
|
||||
self.optim_g.step()
|
||||
|
||||
metrics = {
|
||||
"g_l1_loss": loss_mel,
|
||||
"g_loss_fm_mpd": loss_fm_mpd,
|
||||
"g_loss_fm_msd": loss_fm_msd,
|
||||
"g_loss_gen_mpd": loss_gen_mpd,
|
||||
"g_loss_gen_msd": loss_gen_msd,
|
||||
"g_loss": loss_g,
|
||||
"d_loss_mpd": loss_disc_mpd,
|
||||
"d_loss_msd": loss_disc_msd,
|
||||
"d_loss": loss_d,
|
||||
"global_step": self.global_step,
|
||||
"lr": self.optim_g.param_groups[0]['lr'],
|
||||
}
|
||||
self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
audio, audio_len = batch
|
||||
audio_mel, audio_mel_len = self.audio_to_melspec_precessor(audio, audio_len)
|
||||
audio_pred = self(spec=audio_mel)
|
||||
|
||||
audio_pred_mel, _ = self.audio_to_melspec_precessor(audio_pred.squeeze(1), audio_len)
|
||||
loss_mel = F.l1_loss(audio_mel, audio_pred_mel)
|
||||
|
||||
self.log("val_loss", loss_mel, prog_bar=True, sync_dist=True)
|
||||
|
||||
# plot audio once per epoch
|
||||
if batch_idx == 0 and isinstance(self.logger, WandbLogger):
|
||||
clips = []
|
||||
specs = []
|
||||
for i in range(min(5, audio.shape[0])):
|
||||
clips += [
|
||||
wandb.Audio(
|
||||
audio[i, : audio_len[i]].data.cpu().numpy(),
|
||||
caption=f"real audio {i}",
|
||||
sample_rate=self.sample_rate,
|
||||
),
|
||||
wandb.Audio(
|
||||
audio_pred[i, 0, : audio_len[i]].data.cpu().numpy().astype('float32'),
|
||||
caption=f"generated audio {i}",
|
||||
sample_rate=self.sample_rate,
|
||||
),
|
||||
]
|
||||
specs += [
|
||||
wandb.Image(
|
||||
plot_spectrogram_to_numpy(audio_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
|
||||
caption=f"real audio {i}",
|
||||
),
|
||||
wandb.Image(
|
||||
plot_spectrogram_to_numpy(audio_pred_mel[i, :, : audio_mel_len[i]].data.cpu().numpy()),
|
||||
caption=f"generated audio {i}",
|
||||
),
|
||||
]
|
||||
|
||||
self.logger.experiment.log({"audio": clips, "specs": specs}, commit=False)
|
||||
|
||||
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
|
||||
if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
|
||||
raise ValueError(f"No dataset for {name}")
|
||||
if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
|
||||
raise ValueError(f"No dataloder_params for {name}")
|
||||
if shuffle_should_be:
|
||||
if 'shuffle' not in cfg.dataloader_params:
|
||||
logging.warning(
|
||||
f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
|
||||
"config. Manually setting to True"
|
||||
)
|
||||
with open_dict(cfg["dataloader_params"]):
|
||||
cfg.dataloader_params.shuffle = True
|
||||
elif not cfg.dataloader_params.shuffle:
|
||||
logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
|
||||
elif not shuffle_should_be and cfg.dataloader_params.shuffle:
|
||||
logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")
|
||||
|
||||
dataset = instantiate(cfg.dataset)
|
||||
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
|
||||
|
||||
def setup_training_data(self, cfg):
|
||||
self._train_dl = self.__setup_dataloader_from_config(cfg)
|
||||
|
||||
def setup_validation_data(self, cfg):
|
||||
self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")
|
||||
|
||||
@classmethod
|
||||
def list_available_models(cls) -> 'Optional[Dict[str, str]]':
|
||||
# TODO
|
||||
pass
|
434
nemo/collections/tts/modules/hifigan_modules.py
Normal file
434
nemo/collections/tts/modules/hifigan_modules.py
Normal file
|
@ -0,0 +1,434 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Jungil Kong
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
|
||||
# ResBlock1, ResBlock2, Generator, DiscriminatorP, DiscriminatorS, MultiScaleDiscriminator,
|
||||
# MultiPeriodDiscriminator, init_weights, get_padding
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
from nemo.core.classes import NeuralModule
|
||||
from nemo.core.classes.common import typecheck
|
||||
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
|
||||
from nemo.core.neural_types.neural_type import NeuralType
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size, dilation):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size, dilation):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class Generator(NeuralModule):
|
||||
def __init__(
|
||||
self,
|
||||
resblock,
|
||||
upsample_rates,
|
||||
upsample_kernel_sizes,
|
||||
upsample_initial_channel,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = weight_norm(Conv1d(80, upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if resblock == 1 else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2 ** i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"x": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, x):
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class DiscriminatorP(NeuralModule):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"x": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"decision": NeuralType(('B', 'T'), VoidType()),
|
||||
"feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(NeuralModule):
|
||||
def __init__(self):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorP(2), DiscriminatorP(3), DiscriminatorP(5), DiscriminatorP(7), DiscriminatorP(11),]
|
||||
)
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"y": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
"y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"real_scores": NeuralType(('B', 'T'), VoidType()),
|
||||
"fake_scores": NeuralType(('B', 'T'), VoidType()),
|
||||
"real_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
|
||||
"fake_feature_maps": NeuralType(("B", "C", "H", "W"), VoidType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(x=y)
|
||||
y_d_g, fmap_g = d(x=y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorS(NeuralModule):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
return {
|
||||
"x": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"decision": NeuralType(('B', 'T'), VoidType()),
|
||||
"feature_maps": NeuralType(("B", "C", "T"), VoidType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(NeuralModule):
|
||||
def __init__(self):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorS(use_spectral_norm=True), DiscriminatorS(), DiscriminatorS(),]
|
||||
)
|
||||
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"y": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
"y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
|
||||
}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return {
|
||||
"real_scores": NeuralType(('B', 'T'), VoidType()),
|
||||
"fake_scores": NeuralType(('B', 'T'), VoidType()),
|
||||
"real_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
|
||||
"fake_feature_maps": NeuralType(("B", "C", "T"), VoidType()),
|
||||
}
|
||||
|
||||
@typecheck()
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
y = self.meanpools[i - 1](y)
|
||||
y_hat = self.meanpools[i - 1](y_hat)
|
||||
y_d_r, fmap_r = d(x=y)
|
||||
y_d_g, fmap_g = d(x=y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
Loading…
Reference in a new issue