tacotron2 and waveglow checkpoints hosted on ngc

This commit is contained in:
Krzysztof Kudrynski 2019-12-02 17:21:50 +01:00
parent 5639257b22
commit 4c12145030

View file

@ -3,6 +3,7 @@ import torch
import os
import sys
# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
def checkpoint_from_distributed(state_dict):
"""
@ -34,6 +35,7 @@ def unwrap_distributed(state_dict):
new_state_dict[new_key] = value
return new_state_dict
def _download_checkpoint(checkpoint, force_reload):
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
if not os.path.exists(model_dir):
@ -44,6 +46,7 @@ def _download_checkpoint(checkpoint, force_reload):
urllib.request.urlretrieve(checkpoint, ckpt_file)
return ckpt_file
dependencies = ['torch']
@ -133,9 +136,9 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
if pretrained:
if fp16:
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp16/versions/1/files/nvidia_tacotron2pyt_fp16_20190306.pth'
else:
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/1/files/nvidia_tacotron2pyt_fp32_20190306.pth'
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
@ -190,9 +193,9 @@ def nvidia_waveglow(pretrained=True, **kwargs):
if pretrained:
if fp16:
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp16/versions/1/files/nvidia_waveglowpyt_fp16_20190306.pth'
else:
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth'
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
@ -326,6 +329,7 @@ def nvidia_ssd_processing_utils():
return Processing()
def nvidia_ssd(pretrained=True, **kwargs):
"""Constructs an SSD300 model.
For detailed information on model input and output, training recipies, inference and performance
@ -360,6 +364,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssdpyt_fp16/versions/1/files/nvidia_ssdpyt_fp16_20190225.pt'
else:
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssdpyt_fp32/versions/1/files/nvidia_ssdpyt_fp32_20190225.pt'
# ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=True, check_hash=False)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
ckpt = ckpt['model']