caching of pre-trained weights added to entrypoints
This commit is contained in:
parent
14fe91fad5
commit
161a4ea165
23
hubconf.py
23
hubconf.py
|
@ -1,5 +1,7 @@
|
|||
import urllib.request
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
|
||||
def checkpoint_from_distributed(state_dict):
|
||||
|
@ -54,6 +56,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
|
|||
from PyTorch.Recommendation.NCF import neumf as ncf
|
||||
|
||||
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
|
||||
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
|
||||
|
||||
config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
|
||||
'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs':[0, 0, 0, 0], 'dropout': 0.5}
|
||||
|
@ -63,8 +66,10 @@ def nvidia_ncf(pretrained=True, **kwargs):
|
|||
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
|
||||
ckpt_file = "ncf_ckpt.pt"
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = os.path.basename(checkpoint)
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
|
||||
if checkpoint_from_distributed(ckpt):
|
||||
|
@ -117,14 +122,17 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
|
|||
from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float
|
||||
|
||||
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
|
||||
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
|
||||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
|
||||
ckpt_file = "tacotron2_ckpt.pt"
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = os.path.basename(checkpoint)
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
if checkpoint_from_distributed(state_dict):
|
||||
|
@ -172,14 +180,17 @@ def nvidia_waveglow(pretrained=True, **kwargs):
|
|||
from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float
|
||||
|
||||
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
|
||||
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
|
||||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
|
||||
ckpt_file = "waveglow_ckpt.pt"
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = os.path.basename(checkpoint)
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
if checkpoint_from_distributed(state_dict):
|
||||
|
|
Loading…
Reference in a new issue