DeepLearningExamples/PyTorch/SpeechSynthesis/Tacotron2/models.py

148 lines
5.9 KiB
Python

# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import sys
from os.path import abspath, dirname
# enabling modules discovery from global entrypoint
sys.path.append(abspath(dirname(__file__)+'/'))
from tacotron2.model import Tacotron2
from waveglow.model import WaveGlow
import torch
def parse_model_args(model_name, parser, add_help=False):
    """Register the command-line options of the named model on *parser*.

    The model-specific argument module is imported lazily, so only the
    package of the requested model is loaded.

    Args:
        model_name: 'Tacotron2' or 'WaveGlow'.
        parser: argument parser handed to the model-specific parse function.
        add_help: forwarded unchanged to the model-specific parse function.

    Returns:
        Whatever the model-specific parse function returns (presumably the
        extended parser -- confirm in tacotron2/waveglow ``arg_parser``).

    Raises:
        NotImplementedError: if *model_name* is not a known model.
    """
    if model_name == 'Tacotron2':
        from tacotron2.arg_parser import parse_tacotron2_args as add_model_args
    elif model_name == 'WaveGlow':
        from waveglow.arg_parser import parse_waveglow_args as add_model_args
    else:
        raise NotImplementedError(model_name)
    return add_model_args(parser, add_help)
def batchnorm_to_float(module):
    """Cast every batch-norm layer in *module* (including *module* itself) to FP32.

    The conversion is done in place; other layers keep their dtype.

    Args:
        module: a ``torch.nn.Module`` tree.

    Returns:
        The same *module* object, for call chaining.
    """
    bn_type = torch.nn.modules.batchnorm._BatchNorm
    # modules() walks the module itself and all descendants in pre-order.
    for submodule in module.modules():
        if isinstance(submodule, bn_type):
            submodule.float()
    return module
def init_bn(module):
    """Re-initialize the affine weight of every batch-norm layer in *module*.

    Each affine batch-norm weight is overwritten in place with samples from
    U(0, 1). Biases, running statistics and non-batch-norm layers are left
    untouched. Layers built with ``affine=False`` have no weight and are skipped.

    Args:
        module: a ``torch.nn.Module`` tree.

    Returns:
        The same *module* object. This matches ``batchnorm_to_float``; the
        original implementation returned None, and callers that ignore the
        result are unaffected.
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if module.affine:
            # torch.nn.init.uniform_ defaults to [0, 1), the same range as
            # Tensor.uniform_(). It performs the write under no_grad instead
            # of going through the legacy `.data` attribute.
            torch.nn.init.uniform_(module.weight)
    for child in module.children():
        init_bn(child)
    return module
def get_model(model_name, model_config, to_cuda,
              uniform_initialize_bn_weight=False, forward_is_infer=False):
    """Build a model instance selected by name.

    Args:
        model_name: 'Tacotron2' or 'WaveGlow'.
        model_config: keyword arguments passed to the model constructor.
        to_cuda: if true, the model is moved with ``.cuda()`` before return.
        uniform_initialize_bn_weight: if true, batch-norm weights are
            re-initialized with ``init_bn`` after construction.
        forward_is_infer: if true, a subclass is instantiated whose
            ``forward`` delegates to ``infer``, so calling the model runs
            its inference path.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if *model_name* is not a known model.
    """
    if model_name == 'Tacotron2':
        model_cls = Tacotron2
        if forward_is_infer:
            class Tacotron2__forward_is_infer(Tacotron2):
                def forward(self, inputs, input_lengths):
                    return self.infer(inputs, input_lengths)
            model_cls = Tacotron2__forward_is_infer
    elif model_name == 'WaveGlow':
        model_cls = WaveGlow
        if forward_is_infer:
            class WaveGlow__forward_is_infer(WaveGlow):
                def forward(self, spect, sigma=1.0):
                    return self.infer(spect, sigma)
            model_cls = WaveGlow__forward_is_infer
    else:
        raise NotImplementedError(model_name)

    model = model_cls(**model_config)
    if uniform_initialize_bn_weight:
        init_bn(model)
    return model.cuda() if to_cuda else model
def get_model_config(model_name, args):
    """Collect the constructor keyword arguments of a model from parsed args.

    The returned dict is presumably meant to be passed as ``model_config`` to
    ``get_model``, which forwards it as ``**model_config`` to the constructor.

    Args:
        model_name: 'Tacotron2' or 'WaveGlow'.
        args: parsed command-line namespace carrying the model's options.

    Returns:
        A dict of keyword arguments for the selected model.

    Raises:
        NotImplementedError: if *model_name* is not a known model.
        AttributeError: if *args* lacks one of the required options.
    """
    if model_name == 'Tacotron2':
        # Every Tacotron2 constructor kwarg has the same name as its option.
        tacotron2_option_names = (
            'mask_padding',                                          # optimization
            'n_mel_channels',                                        # audio
            'n_symbols', 'symbols_embedding_dim',                    # symbols
            'encoder_kernel_size', 'encoder_n_convolutions',
            'encoder_embedding_dim',                                 # encoder
            'attention_rnn_dim', 'attention_dim',                    # attention
            'attention_location_n_filters',
            'attention_location_kernel_size',                        # attention location
            'n_frames_per_step', 'decoder_rnn_dim', 'prenet_dim',
            'max_decoder_steps', 'gate_threshold',
            'p_attention_dropout', 'p_decoder_dropout',              # decoder
            'postnet_embedding_dim', 'postnet_kernel_size',
            'postnet_n_convolutions',                                # postnet
            'decoder_no_early_stopping',
        )
        return {name: getattr(args, name) for name in tacotron2_option_names}

    if model_name == 'WaveGlow':
        # WaveGlow options are renamed on the way to the constructor
        # (e.g. --flows -> n_flows), and the WN settings are nested.
        return {
            'n_mel_channels': args.n_mel_channels,
            'n_flows': args.flows,
            'n_group': args.groups,
            'n_early_every': args.early_every,
            'n_early_size': args.early_size,
            'WN_config': {
                'n_layers': args.wn_layers,
                'kernel_size': args.wn_kernel_size,
                'n_channels': args.wn_channels,
            },
        }

    raise NotImplementedError(model_name)