tacotron2 and waveglow checkpoints hosted on ngc
This commit is contained in:
parent
5639257b22
commit
4c12145030
13
hubconf.py
13
hubconf.py
|
@ -3,6 +3,7 @@ import torch
|
|||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
|
||||
def checkpoint_from_distributed(state_dict):
|
||||
"""
|
||||
|
@ -34,6 +35,7 @@ def unwrap_distributed(state_dict):
|
|||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def _download_checkpoint(checkpoint, force_reload):
|
||||
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
|
||||
if not os.path.exists(model_dir):
|
||||
|
@ -44,6 +46,7 @@ def _download_checkpoint(checkpoint, force_reload):
|
|||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
return ckpt_file
|
||||
|
||||
|
||||
dependencies = ['torch']
|
||||
|
||||
|
||||
|
@ -133,9 +136,9 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
|
|||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp16/versions/1/files/nvidia_tacotron2pyt_fp16_20190306.pth'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/1/files/nvidia_tacotron2pyt_fp32_20190306.pth'
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
|
@ -190,9 +193,9 @@ def nvidia_waveglow(pretrained=True, **kwargs):
|
|||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp16/versions/1/files/nvidia_waveglowpyt_fp16_20190306.pth'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth'
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
|
@ -326,6 +329,7 @@ def nvidia_ssd_processing_utils():
|
|||
|
||||
return Processing()
|
||||
|
||||
|
||||
def nvidia_ssd(pretrained=True, **kwargs):
|
||||
"""Constructs an SSD300 model.
|
||||
For detailed information on model input and output, training recipies, inference and performance
|
||||
|
@ -360,6 +364,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
|
|||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssdpyt_fp16/versions/1/files/nvidia_ssdpyt_fp16_20190225.pt'
|
||||
else:
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssdpyt_fp32/versions/1/files/nvidia_ssdpyt_fp32_20190225.pt'
|
||||
# ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=True, check_hash=False)
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
ckpt = ckpt['model']
|
||||
|
|
Loading…
Reference in a new issue