torch.hub entrypoints for tacotron2, waveglow and ncf

Krzysztof Kudrynski 2019-05-15 14:16:46 +02:00
parent a9d0554e5d
commit 49acab486b
7 changed files with 234 additions and 24 deletions

View file

@@ -32,6 +32,11 @@ import numpy as np
 import torch
 import torch.nn as nn
+import sys
+from os.path import abspath, join, dirname
+# enabling modules discovery from global entrypoint
+sys.path.append(abspath(dirname(__file__)+'/'))
 from logger.logger import LOGGER
 from logger import tags
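
Note on the sys.path line: torch.hub loads hubconf.py from the repository root, so sibling packages of this script (logger, common, ...) are not importable by default; appending the script's own directory makes them resolvable regardless of the caller's working directory. A minimal sketch of the pattern (same idea as the hunk above):

import sys
from os.path import abspath, dirname

# make packages that live next to this file importable even when the file
# itself is imported from elsewhere (e.g. via a torch.hub entrypoint)
sys.path.append(abspath(dirname(__file__) + '/'))
from logger.logger import LOGGER  # now resolves independently of the CWD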

View file

@@ -190,7 +190,7 @@ def main():
     if args.tacotron2:
         tacotron2_t0 = time.time()
         with torch.no_grad():
-            _, mel, _, _ = tacotron2.inference(sequence)
+            _, mel, _, _ = tacotron2.infer(sequence)
         tacotron2_t1 = time.time()
         tacotron2_infer_perf = sequence.size(1)/(tacotron2_t1-tacotron2_t0)
         LOGGER.log(key="tacotron2_items_per_sec", value=tacotron2_infer_perf)
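
The throughput metric above divides the number of input symbols, sequence.size(1), by wall-clock inference time. A hedged sketch of the same measurement; the torch.cuda.synchronize() calls are an assumption for accurate GPU timing and are not part of this diff:

import time
import torch

torch.cuda.synchronize()                    # assumed: flush pending GPU work first
tacotron2_t0 = time.time()
with torch.no_grad():
    _, mel, _, _ = tacotron2.infer(sequence)
torch.cuda.synchronize()                    # assumed: wait for infer() to finish
tacotron2_t1 = time.time()
tacotron2_infer_perf = sequence.size(1) / (tacotron2_t1 - tacotron2_t0)  # items/sec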

View file

@@ -141,7 +141,7 @@ def main():
     t0 = time.time()
     with torch.no_grad():
-        _, mels, _, _ = model.inference(text_padded)
+        _, mels, _, _ = model.infer(text_padded)
     t1 = time.time()
     inference_time= t1 - t0
     num_items = text_padded.size(0)*text_padded.size(1)

View file

@@ -27,22 +27,22 @@
 from tacotron2.model import Tacotron2
 from waveglow.model import WaveGlow
-from tacotron2.arg_parser import parse_tacotron2_args
-from waveglow.arg_parser import parse_waveglow_args
 import torch


 def parse_model_args(model_name, parser, add_help=False):
     if model_name == 'Tacotron2':
+        from tacotron2.arg_parser import parse_tacotron2_args
         return parse_tacotron2_args(parser, add_help)
     if model_name == 'WaveGlow':
+        from waveglow.arg_parser import parse_waveglow_args
         return parse_waveglow_args(parser, add_help)
     else:
         raise NotImplementedError(model_name)


 def batchnorm_to_float(module):
-    """Converts LSTMCells to FP32"""
+    """Converts batch norm to FP32"""
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
         module.float()
     for child in module.children():

@@ -51,7 +51,7 @@ def batchnorm_to_float(module):
 def lstmcell_to_float(module):
-    """Converts batch norm modules to FP32"""
+    """Converts LSTMCells modules to FP32"""
     if isinstance(module, torch.nn.LSTMCell):
         module.float()
     for child in module.children():
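
Beyond the swapped docstrings fixed above, the two helpers are complementary and implement this commit's mixed-precision scheme: convert the whole model to FP16, then recursively cast numerically sensitive modules (batch norm, LSTMCell) back to FP32. A usage sketch, mirroring how hubconf.py applies them below:

model = Tacotron2(**config)
if fp16:
    model = batchnorm_to_float(model.half())  # FP16 weights, FP32 batch norm
    model = lstmcell_to_float(model)          # keep LSTMCell math in FP32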

View file

@@ -30,6 +30,10 @@ import torch
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as F
+import sys
+from os.path import abspath, dirname
+# enabling modules discovery from global entrypoint
+sys.path.append(abspath(dirname(__file__)+'/../'))
 from common.layers import ConvNorm, LinearNorm
 from common.utils import to_gpu, get_mask_from_lengths

@@ -375,7 +379,7 @@ class Decoder(nn.Module):
         return mel_outputs, gate_outputs, alignments

-    def decode(self, decoder_input, is_infer=False):
+    def decode(self, decoder_input):
         """ Decoder step using stored states, attention and memory
         PARAMS
         ------

@@ -390,12 +394,8 @@
         cell_input = torch.cat((decoder_input, self.attention_context), -1)

         attention_hidden_dtype = self.attention_hidden.dtype
-        if is_infer:
-            self.attention_hidden, self.attention_cell = self.attention_rnn(
-                cell_input, (self.attention_hidden, self.attention_cell))
-        else:
-            self.attention_hidden, self.attention_cell = self.attention_rnn(
-                cell_input.float(), (self.attention_hidden.float(), self.attention_cell.float()))
+        self.attention_hidden, self.attention_cell = self.attention_rnn(
+            cell_input.float(), (self.attention_hidden.float(), self.attention_cell.float()))

         self.attention_hidden = F.dropout(
             self.attention_hidden, self.p_attention_dropout, self.training)

@@ -418,13 +418,9 @@ class Decoder(nn.Module):
             (self.attention_hidden, self.attention_context), -1)

         decoder_hidden_dtype = self.decoder_hidden.dtype
-        if is_infer:
-            self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
-                decoder_input, (self.decoder_hidden, self.decoder_cell))
-        else:
-            self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
-                decoder_input.float(), (self.decoder_hidden.float(), self.decoder_cell.float()))
+        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
+            decoder_input.float(), (self.decoder_hidden.float(), self.decoder_cell.float()))

         self.decoder_hidden = F.dropout(
             self.decoder_hidden, self.p_decoder_dropout, self.training)

@@ -539,7 +535,7 @@ class Decoder(nn.Module):
         mel_outputs, gate_outputs, alignments = [], [], []
         while True:
             decoder_input = self.prenet(decoder_input)
-            mel_output, gate_output, alignment = self.decode(decoder_input, is_infer=True)
+            mel_output, gate_output, alignment = self.decode(decoder_input)

             mel_outputs += [mel_output.squeeze(1)]
             gate_outputs += [gate_output]

@@ -647,7 +643,7 @@ class Tacotron2(nn.Module):
             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
             output_lengths)

-    def inference(self, inputs):
+    def infer(self, inputs):
         inputs = self.parse_input(inputs)
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
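
With lstmcell_to_float keeping the attention and decoder LSTMCells in FP32, the old is_infer branch becomes redundant: decode now unconditionally casts cell inputs and recurrent states to float, and the saved *_dtype values allow casting back to the model's working precision. A standalone sketch of that precision pattern (toy sizes, not the repo's code):

import torch

cell = torch.nn.LSTMCell(8, 8).float()               # kept in FP32 by lstmcell_to_float
x = torch.randn(1, 8).half()                         # FP16 activations from the rest of the model
h = torch.zeros(1, 8).half()
c = torch.zeros(1, 8).half()

dtype = h.dtype                                      # remember working dtype (cf. attention_hidden_dtype)
h32, c32 = cell(x.float(), (h.float(), c.float()))   # FP32 cell math, as in decode()
h, c = h32.to(dtype), c32.to(dtype)                  # assumed cast-back outside the visible hunk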

View file

@@ -240,7 +240,7 @@ def save_sample(model_name, model, waveglow_path, tacotron2_path, phrase_path, f
             'WaveGlow', checkpoint['config'], to_fp16=False, to_cuda=False)
         waveglow.eval()
         model.eval()
-        mel = model.inference(phrase.cuda())[0].cpu()
+        mel = model.infer(phrase.cuda())[0].cpu()
         model.train()
         if fp16:
             mel = mel.float()

@@ -254,7 +254,7 @@
         tacotron2 = models.get_model(
             'Tacotron2', checkpoint['config'], to_fp16=False, to_cuda=False)
         tacotron2.eval()
-        mel = tacotron2.inference(phrase)[0].cuda()
+        mel = tacotron2.infer(phrase)[0].cuda()
         model.eval()
         if fp16:
             mel = mel.half()

hubconf.py Normal file
View file

@@ -0,0 +1,209 @@
import urllib.request
import torch


# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
def checkpoint_from_distributed(state_dict):
    """
    Checks whether the checkpoint was generated by DistributedDataParallel.
    DDP wraps the model in an additional "module.", which needs to be unwrapped
    for single-GPU inference.
    :param state_dict: model's state dict
    """
    ret = False
    for key, _ in state_dict.items():
        if key.find('module.') != -1:
            ret = True
            break
    return ret


# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
def unwrap_distributed(state_dict):
    """
    Unwraps the model from DistributedDataParallel.
    DDP wraps the model in an additional "module.", which needs to be removed
    for single-GPU inference.
    :param state_dict: model's state dict
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace('module.1.', '')
        new_key = new_key.replace('module.', '')
        new_state_dict[new_key] = value
    return new_state_dict
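
These two helpers are used in tandem by every entrypoint below when a pretrained checkpoint is downloaded; a minimal sketch (the file name and map_location are illustrative):

ckpt = torch.load('tacotron2_ckpt.pt', map_location='cpu')  # assumed: CPU map for portability
state_dict = ckpt['state_dict']
if checkpoint_from_distributed(state_dict):
    state_dict = unwrap_distributed(state_dict)  # strip DDP's 'module.' prefix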
dependencies = ['torch']
def nvidia_ncf(pretrained=True, **kwargs):
    """Constructs an NCF model.
    For detailed information on model input and output, training recipes, inference and
    performance, visit github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    Args (type[, default value]):
        pretrained (bool, True): If True, returns a model pretrained on the ml-20m dataset.
        model_math (str, 'fp32'): returns a model in the given precision ('fp32' or 'fp16')
        nb_users (int): number of users
        nb_items (int): number of items
        mf_dim (int, 64): dimension of the latent space in matrix factorization
        mlp_layer_sizes (list, [256,256,128,64]): sizes of the layers of the multi-layer perceptron
        dropout (float, 0.5): dropout
    """
    from PyTorch.Recommendation.NCF import neumf as ncf

    fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"

    config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
              'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs': [0, 0, 0, 0],
              'dropout': 0.5}

    if pretrained:
        if fp16:
            checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
        else:
            checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
        ckpt_file = "ncf_ckpt.pt"
        urllib.request.urlretrieve(checkpoint, ckpt_file)
        ckpt = torch.load(ckpt_file)
        if checkpoint_from_distributed(ckpt):
            ckpt = unwrap_distributed(ckpt)
        # infer model dimensions from the checkpoint's embedding shapes
        config['nb_users'] = ckpt['mf_user_embed.weight'].shape[0]
        config['nb_items'] = ckpt['mf_item_embed.weight'].shape[0]
        config['mf_dim'] = ckpt['mf_item_embed.weight'].shape[1]
        mlp_shapes = [ckpt[k].shape for k in ckpt.keys()
                      if 'mlp' in k and 'weight' in k and 'embed' not in k]
        config['mlp_layer_sizes'] = [mlp_shapes[0][1], mlp_shapes[1][1], mlp_shapes[2][1], mlp_shapes[2][0]]
        config['mlp_layer_regs'] = [0] * len(config['mlp_layer_sizes'])
    else:
        if 'nb_users' not in kwargs:
            raise ValueError("Missing 'nb_users' argument.")
        if 'nb_items' not in kwargs:
            raise ValueError("Missing 'nb_items' argument.")
        for k, v in kwargs.items():
            if k in config.keys():
                config[k] = v
        config['mlp_layer_regs'] = [0] * len(config['mlp_layer_sizes'])

    m = ncf.NeuMF(**config)

    if fp16:
        m.half()

    if pretrained:
        m.load_state_dict(ckpt)

    return m
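
Once this hubconf.py is on the repository's default branch, the entrypoint should be callable through torch.hub; a hedged usage sketch (the 'NVIDIA/DeepLearningExamples' hub path is assumed, and the ml-20m user/item counts shown are illustrative):

import torch

# pretrained: user/item counts and layer sizes are inferred from the checkpoint
ncf = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ncf',
                     pretrained=True, model_math='fp32')
ncf.eval()

# untrained: nb_users and nb_items must then be supplied explicitly
ncf_raw = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ncf',
                         pretrained=False, nb_users=138493, nb_items=26744)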
def nvidia_tacotron2(pretrained=True, **kwargs):
    """Constructs a Tacotron 2 model (nn.Module with an additional infer(input) method).
    For detailed information on model input and output, training recipes, inference and
    performance, visit github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    Args (type[, default value]):
        pretrained (bool, True): If True, returns a model pretrained on the LJ Speech dataset.
        model_math (str, 'fp32'): returns a model in the given precision ('fp32' or 'fp16')
        n_symbols (int, 148): number of symbols used in the input sequence, see
            https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/text/symbols.py
        p_attention_dropout (float, 0.1): dropout probability on the attention LSTM (1st LSTM layer in the decoder)
        p_decoder_dropout (float, 0.1): dropout probability on the decoder LSTM (2nd LSTM layer in the decoder)
        max_decoder_steps (int, 1000): maximum number of mel spectrogram frames generated during inference
    """
    from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import model as tacotron2
    from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float

    fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"

    if pretrained:
        if fp16:
            checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
        else:
            checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
        ckpt_file = "tacotron2_ckpt.pt"
        urllib.request.urlretrieve(checkpoint, ckpt_file)
        ckpt = torch.load(ckpt_file)
        state_dict = ckpt['state_dict']
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)
        config = ckpt['config']
    else:
        config = {'mask_padding': False, 'n_mel_channels': 80, 'n_symbols': 148,
                  'symbols_embedding_dim': 512, 'encoder_kernel_size': 5,
                  'encoder_n_convolutions': 3, 'encoder_embedding_dim': 512,
                  'attention_rnn_dim': 1024, 'attention_dim': 128,
                  'attention_location_n_filters': 32,
                  'attention_location_kernel_size': 31, 'n_frames_per_step': 1,
                  'decoder_rnn_dim': 1024, 'prenet_dim': 256,
                  'max_decoder_steps': 1000, 'gate_threshold': 0.5,
                  'p_attention_dropout': 0.1, 'p_decoder_dropout': 0.1,
                  'postnet_embedding_dim': 512, 'postnet_kernel_size': 5,
                  'postnet_n_convolutions': 5, 'decoder_no_early_stopping': False}
        for k, v in kwargs.items():
            if k in config.keys():
                config[k] = v

    m = tacotron2.Tacotron2(**config)

    if fp16:
        m = batchnorm_to_float(m.half())
        m = lstmcell_to_float(m)

    if pretrained:
        m.load_state_dict(state_dict)

    return m
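
A corresponding usage sketch for the Tacotron 2 entrypoint (hub path assumed as above; 'sequence' stands for a LongTensor of encoded text symbols of shape (1, T), produced by the repo's text processing utilities):

import torch

tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_tacotron2',
                           pretrained=True, model_math='fp16')
tacotron2 = tacotron2.cuda().eval()

with torch.no_grad():
    _, mel, _, _ = tacotron2.infer(sequence)  # mel: (1, n_mel_channels, frames)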
def nvidia_waveglow(pretrained=True, **kwargs):
    """Constructs a WaveGlow model (nn.Module with an additional infer(input) method).
    For detailed information on model input and output, training recipes, inference and
    performance, visit github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    Args (type[, default value]):
        pretrained (bool, True): If True, returns a model pretrained on the LJ Speech dataset.
        model_math (str, 'fp32'): returns a model in the given precision ('fp32' or 'fp16')
    """
    from PyTorch.SpeechSynthesis.Tacotron2.waveglow import model as waveglow
    from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float

    fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"

    if pretrained:
        if fp16:
            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
        else:
            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
        ckpt_file = "waveglow_ckpt.pt"
        urllib.request.urlretrieve(checkpoint, ckpt_file)
        ckpt = torch.load(ckpt_file)
        state_dict = ckpt['state_dict']
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)
        config = ckpt['config']
    else:
        config = {'n_mel_channels': 80, 'n_flows': 12, 'n_group': 8,
                  'n_early_every': 4, 'n_early_size': 2,
                  'WN_config': {'n_layers': 8, 'kernel_size': 3,
                                'n_channels': 512}}
        for k, v in kwargs.items():
            if k in config.keys():
                config[k] = v
            elif k in config['WN_config'].keys():
                config['WN_config'][k] = v

    m = waveglow.WaveGlow(**config)

    if fp16:
        m = batchnorm_to_float(m.half())
        # keep the invertible 1x1 convolutions in FP32
        for mat in m.convinv:
            mat.float()

    if pretrained:
        m.load_state_dict(state_dict)

    return m
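
And the end-to-end picture, chaining the two speech entrypoints (same hub-path assumption; infer() is the method the docstrings above advertise):

import torch

waveglow = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_waveglow',
                          pretrained=True, model_math='fp32')
waveglow = waveglow.cuda().eval()

with torch.no_grad():
    audio = waveglow.infer(mel.cuda())  # mel from tacotron2.infer(...) above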