diff --git a/PyTorch/Recommendation/NCF/neumf.py b/PyTorch/Recommendation/NCF/neumf.py index c62bfbbc..60dbfb54 100644 --- a/PyTorch/Recommendation/NCF/neumf.py +++ b/PyTorch/Recommendation/NCF/neumf.py @@ -32,6 +32,11 @@ import numpy as np import torch import torch.nn as nn +import sys +from os.path import abspath, join, dirname +# enabling modules discovery from global entrypoint +sys.path.append(abspath(dirname(__file__)+'/')) + from logger.logger import LOGGER from logger import tags diff --git a/PyTorch/SpeechSynthesis/Tacotron2/inference.py b/PyTorch/SpeechSynthesis/Tacotron2/inference.py index b73682d9..c2a64c16 100644 --- a/PyTorch/SpeechSynthesis/Tacotron2/inference.py +++ b/PyTorch/SpeechSynthesis/Tacotron2/inference.py @@ -190,7 +190,7 @@ def main(): if args.tacotron2: tacotron2_t0 = time.time() with torch.no_grad(): - _, mel, _, _ = tacotron2.inference(sequence) + _, mel, _, _ = tacotron2.infer(sequence) tacotron2_t1 = time.time() tacotron2_infer_perf = sequence.size(1)/(tacotron2_t1-tacotron2_t0) LOGGER.log(key="tacotron2_items_per_sec", value=tacotron2_infer_perf) diff --git a/PyTorch/SpeechSynthesis/Tacotron2/inference_perf.py b/PyTorch/SpeechSynthesis/Tacotron2/inference_perf.py index 20bf6096..21028253 100644 --- a/PyTorch/SpeechSynthesis/Tacotron2/inference_perf.py +++ b/PyTorch/SpeechSynthesis/Tacotron2/inference_perf.py @@ -141,7 +141,7 @@ def main(): t0 = time.time() with torch.no_grad(): - _, mels, _, _ = model.inference(text_padded) + _, mels, _, _ = model.infer(text_padded) t1 = time.time() inference_time= t1 - t0 num_items = text_padded.size(0)*text_padded.size(1) diff --git a/PyTorch/SpeechSynthesis/Tacotron2/models.py b/PyTorch/SpeechSynthesis/Tacotron2/models.py index 6c3eb7fb..12298d79 100644 --- a/PyTorch/SpeechSynthesis/Tacotron2/models.py +++ b/PyTorch/SpeechSynthesis/Tacotron2/models.py @@ -27,22 +27,22 @@ from tacotron2.model import Tacotron2 from waveglow.model import WaveGlow -from tacotron2.arg_parser import parse_tacotron2_args -from waveglow.arg_parser import parse_waveglow_args import torch def parse_model_args(model_name, parser, add_help=False): if model_name == 'Tacotron2': + from tacotron2.arg_parser import parse_tacotron2_args return parse_tacotron2_args(parser, add_help) if model_name == 'WaveGlow': + from waveglow.arg_parser import parse_waveglow_args return parse_waveglow_args(parser, add_help) else: raise NotImplementedError(model_name) def batchnorm_to_float(module): - """Converts LSTMCells to FP32""" + """Converts batch norm to FP32""" if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): module.float() for child in module.children(): @@ -51,7 +51,7 @@ def batchnorm_to_float(module): def lstmcell_to_float(module): - """Converts batch norm modules to FP32""" + """Converts LSTMCells modules to FP32""" if isinstance(module, torch.nn.LSTMCell): module.float() for child in module.children(): diff --git a/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py b/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py index d4b939bb..49ce684d 100644 --- a/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py +++ b/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py @@ -30,6 +30,10 @@ import torch from torch.autograd import Variable from torch import nn from torch.nn import functional as F +import sys +from os.path import abspath, dirname +# enabling modules discovery from global entrypoint +sys.path.append(abspath(dirname(__file__)+'/../')) from common.layers import ConvNorm, LinearNorm from common.utils import to_gpu, get_mask_from_lengths @@ 
-375,7 +379,7 @@ class Decoder(nn.Module): return mel_outputs, gate_outputs, alignments - def decode(self, decoder_input, is_infer=False): + def decode(self, decoder_input): """ Decoder step using stored states, attention and memory PARAMS ------ @@ -390,12 +394,8 @@ class Decoder(nn.Module): cell_input = torch.cat((decoder_input, self.attention_context), -1) attention_hidden_dtype = self.attention_hidden.dtype - if is_infer: - self.attention_hidden, self.attention_cell = self.attention_rnn( - cell_input, (self.attention_hidden, self.attention_cell)) - else: - self.attention_hidden, self.attention_cell = self.attention_rnn( - cell_input.float(), (self.attention_hidden.float(), self.attention_cell.float())) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input.float(), (self.attention_hidden.float(), self.attention_cell.float())) self.attention_hidden = F.dropout( self.attention_hidden, self.p_attention_dropout, self.training) @@ -418,13 +418,9 @@ class Decoder(nn.Module): (self.attention_hidden, self.attention_context), -1) decoder_hidden_dtype = self.decoder_hidden.dtype - if is_infer: - self.decoder_hidden, self.decoder_cell = self.decoder_rnn( - decoder_input, (self.decoder_hidden, self.decoder_cell)) - else: - self.decoder_hidden, self.decoder_cell = self.decoder_rnn( - decoder_input.float(), (self.decoder_hidden.float(), self.decoder_cell.float())) - + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_input.float(), (self.decoder_hidden.float(), self.decoder_cell.float())) + self.decoder_hidden = F.dropout( self.decoder_hidden, self.p_decoder_dropout, self.training) @@ -539,7 +535,7 @@ class Decoder(nn.Module): mel_outputs, gate_outputs, alignments = [], [], [] while True: decoder_input = self.prenet(decoder_input) - mel_output, gate_output, alignment = self.decode(decoder_input, is_infer=True) + mel_output, gate_output, alignment = self.decode(decoder_input) mel_outputs += [mel_output.squeeze(1)] gate_outputs += [gate_output] @@ -647,7 +643,7 @@ class Tacotron2(nn.Module): [mel_outputs, mel_outputs_postnet, gate_outputs, alignments], output_lengths) - def inference(self, inputs): + def infer(self, inputs): inputs = self.parse_input(inputs) embedded_inputs = self.embedding(inputs).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) diff --git a/PyTorch/SpeechSynthesis/Tacotron2/train.py b/PyTorch/SpeechSynthesis/Tacotron2/train.py index cbca9abc..2fbe9a22 100644 --- a/PyTorch/SpeechSynthesis/Tacotron2/train.py +++ b/PyTorch/SpeechSynthesis/Tacotron2/train.py @@ -240,7 +240,7 @@ def save_sample(model_name, model, waveglow_path, tacotron2_path, phrase_path, f 'WaveGlow', checkpoint['config'], to_fp16=False, to_cuda=False) waveglow.eval() model.eval() - mel = model.inference(phrase.cuda())[0].cpu() + mel = model.infer(phrase.cuda())[0].cpu() model.train() if fp16: mel = mel.float() @@ -254,7 +254,7 @@ def save_sample(model_name, model, waveglow_path, tacotron2_path, phrase_path, f tacotron2 = models.get_model( 'Tacotron2', checkpoint['config'], to_fp16=False, to_cuda=False) tacotron2.eval() - mel = tacotron2.inference(phrase)[0].cuda() + mel = tacotron2.infer(phrase)[0].cuda() model.eval() if fp16: mel = mel.half() diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 00000000..e01157d0 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,209 @@ +import urllib.request +import torch + +# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py +def 
checkpoint_from_distributed(state_dict):
+    """
+    Checks whether checkpoint was generated by DistributedDataParallel. DDP
+    wraps model in additional "module.", it needs to be unwrapped for single
+    GPU inference.
+    :param state_dict: model's state dict
+    """
+    ret = False
+    for key, _ in state_dict.items():
+        if key.find('module.') != -1:
+            ret = True
+            break
+    return ret
+
+
+# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
+def unwrap_distributed(state_dict):
+    """
+    Unwraps model from DistributedDataParallel.
+    DDP wraps model in additional "module.", it needs to be removed for single
+    GPU inference.
+    :param state_dict: model's state dict
+    """
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = key.replace('module.1.', '')
+        new_key = new_key.replace('module.', '')
+        new_state_dict[new_key] = value
+    return new_state_dict
+
+
+dependencies = ['torch']
+
+
+def nvidia_ncf(pretrained=True, **kwargs):
+    """Constructs an NCF model.
+    For detailed information on model input and output, training recipes, inference and performance
+    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com
+
+    Args:
+        pretrained (bool, True): If True, returns a model pretrained on the ml-20m dataset.
+        model_math (str, 'fp32'): returns a model in given precision ('fp32' or 'fp16')
+        nb_users (int): number of users
+        nb_items (int): number of items
+        mf_dim (int, 64): dimension of latent space in matrix factorization
+        mlp_layer_sizes (list, [256,256,128,64]): sizes of layers of multi-layer-perceptron
+        dropout (float, 0.5): dropout probability
+    """
+
+    from PyTorch.Recommendation.NCF import neumf as ncf
+
+    fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+
+    config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
+              'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs': [0, 0, 0, 0], 'dropout': 0.5}
+
+    if pretrained:
+        if fp16:
+            checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
+        else:
+            checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
+        ckpt_file = "ncf_ckpt.pt"
+        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt = torch.load(ckpt_file)
+
+        if checkpoint_from_distributed(ckpt):
+            ckpt = unwrap_distributed(ckpt)
+
+        config['nb_users'] = ckpt['mf_user_embed.weight'].shape[0]
+        config['nb_items'] = ckpt['mf_item_embed.weight'].shape[0]
+        config['mf_dim'] = ckpt['mf_item_embed.weight'].shape[1]
+        mlp_shapes = [ckpt[k].shape for k in ckpt.keys() if 'mlp' in k and 'weight' in k and 'embed' not in k]
+        config['mlp_layer_sizes'] = [mlp_shapes[0][1], mlp_shapes[1][1], mlp_shapes[2][1], mlp_shapes[2][0]]
+        config['mlp_layer_regs'] = [0] * len(config['mlp_layer_sizes'])
+
+    else:
+        if 'nb_users' not in kwargs:
+            raise ValueError("Missing 'nb_users' argument.")
+        if 'nb_items' not in kwargs:
+            raise ValueError("Missing 'nb_items' argument.")
+        for k, v in kwargs.items():
+            if k in config.keys():
+                config[k] = v
+        config['mlp_layer_regs'] = [0] * len(config['mlp_layer_sizes'])
+
+    m = ncf.NeuMF(**config)
+
+    if fp16:
+        m.half()
+
+    if pretrained:
+        m.load_state_dict(ckpt)
+
+    return m
+
+
+def nvidia_tacotron2(pretrained=True, **kwargs):
+    """Constructs a Tacotron 2 model (nn.Module with additional infer(input) method).
+    For detailed information on model input and output, training recipes, inference and performance
+    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com
+
+    Args (type[, default value]):
+        pretrained (bool, True): If True, returns a model pretrained on the LJ Speech dataset.
+        model_math (str, 'fp32'): returns a model in given precision ('fp32' or 'fp16')
+        n_symbols (int, 148): Number of symbols used in a sequence passed to the prenet, see
+            https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/text/symbols.py
+        p_attention_dropout (float, 0.1): dropout probability on attention LSTM (1st LSTM layer in decoder)
+        p_decoder_dropout (float, 0.1): dropout probability on decoder LSTM (2nd LSTM layer in decoder)
+        max_decoder_steps (int, 1000): maximum number of generated mel spectrograms during inference
+    """
+
+    from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import model as tacotron2
+    from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float
+
+    fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+
+    if pretrained:
+        if fp16:
+            checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
+        else:
+            checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
+        ckpt_file = "tacotron2_ckpt.pt"
+        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt = torch.load(ckpt_file)
+        state_dict = ckpt['state_dict']
+        if checkpoint_from_distributed(state_dict):
+            state_dict = unwrap_distributed(state_dict)
+        config = ckpt['config']
+    else:
+        config = {'mask_padding': False, 'n_mel_channels': 80, 'n_symbols': 148,
+                  'symbols_embedding_dim': 512, 'encoder_kernel_size': 5,
+                  'encoder_n_convolutions': 3, 'encoder_embedding_dim': 512,
+                  'attention_rnn_dim': 1024, 'attention_dim': 128,
+                  'attention_location_n_filters': 32,
+                  'attention_location_kernel_size': 31, 'n_frames_per_step': 1,
+                  'decoder_rnn_dim': 1024, 'prenet_dim': 256,
+                  'max_decoder_steps': 1000, 'gate_threshold': 0.5,
+                  'p_attention_dropout': 0.1, 'p_decoder_dropout': 0.1,
+                  'postnet_embedding_dim': 512, 'postnet_kernel_size': 5,
+                  'postnet_n_convolutions': 5, 'decoder_no_early_stopping': False}
+        for k, v in kwargs.items():
+            if k in config.keys():
+                config[k] = v
+
+    m = tacotron2.Tacotron2(**config)
+
+    if fp16:
+        m = batchnorm_to_float(m.half())
+        m = lstmcell_to_float(m)
+
+    if pretrained:
+        m.load_state_dict(state_dict)
+
+    return m
+
+
+def nvidia_waveglow(pretrained=True, **kwargs):
+    """Constructs a WaveGlow model (nn.Module with additional infer(input) method).
+    For detailed information on model input and output, training recipes, inference and performance
+    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com
+
+    Args:
+        pretrained (bool): If True, returns a model pretrained on the LJ Speech dataset.
+ model_math (str, 'fp32'): returns a model in given precision ('fp32' or 'fp16') + """ + + from PyTorch.SpeechSynthesis.Tacotron2.waveglow import model as waveglow + from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float + + fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16" + + if pretrained: + if fp16: + checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306' + else: + checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306' + ckpt_file = "waveglow_ckpt.pt" + urllib.request.urlretrieve(checkpoint, ckpt_file) + ckpt = torch.load(ckpt_file) + state_dict = ckpt['state_dict'] + if checkpoint_from_distributed(state_dict): + state_dict = unwrap_distributed(state_dict) + config = ckpt['config'] + else: + config = {'n_mel_channels': 80, 'n_flows': 12, 'n_group': 8, + 'n_early_every': 4, 'n_early_size': 2, + 'WN_config': {'n_layers': 8, 'kernel_size': 3, + 'n_channels': 512}} + for k,v in kwargs.items(): + if k in config.keys(): + config[k] = v + elif k in config['WN_config'].keys(): + config['WN_config'][k] = v + + m = waveglow.WaveGlow(**config) + + if fp16: + m = batchnorm_to_float(m.half()) + for mat in m.convinv: + mat.float() + + if pretrained: + m.load_state_dict(state_dict) + + return m
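
Below is a minimal, illustrative usage sketch for the new hub entrypoints; it is not part of the patch. It assumes this hubconf.py sits at the repository root (so torch.hub.load('NVIDIA/DeepLearningExamples', ...) can discover it), that the checkpoint URLs above remain reachable, and that text_to_sequence can be imported through the same namespace-package layout the entrypoints themselves rely on.

import torch

# Entry-point names come from hubconf.py above; pretrained=True downloads the checkpoints listed there.
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_tacotron2', pretrained=True)
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_waveglow', pretrained=True)
tacotron2 = tacotron2.cuda().eval()
waveglow = waveglow.cuda().eval()

# Illustrative text front-end: text_to_sequence lives under tacotron2/text in this repo;
# the exact import path and cleaner name are assumptions, adjust them to your checkout.
from PyTorch.SpeechSynthesis.Tacotron2.tacotron2.text import text_to_sequence
sequence = torch.tensor(text_to_sequence('Hello world.', ['english_cleaners']),
                        dtype=torch.long).unsqueeze(0).cuda()

with torch.no_grad():
    _, mel, _, _ = tacotron2.infer(sequence)  # .infer() replaces .inference() in this patch
    audio = waveglow.infer(mel)

# Without a pretrained checkpoint, NCF needs the dataset sizes (example values shown):
ncf = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ncf',
                     pretrained=False, nb_users=138493, nb_items=26744)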