Add caching of pre-trained weights to entrypoints

Krzysztof Kudrynski 2019-05-21 15:10:38 +02:00
parent 14fe91fad5
commit 161a4ea165


@@ -1,5 +1,7 @@
 import urllib.request
 import torch
+import os
+import sys
 
 # from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
 def checkpoint_from_distributed(state_dict):
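
For context, checkpoint_from_distributed (taken from the inference.py linked in the comment above) detects checkpoints saved from a DistributedDataParallel model, whose parameter keys carry a 'module.' prefix. A minimal sketch of that check, assuming the standard DDP key convention:

def checkpoint_from_distributed(state_dict):
    # DistributedDataParallel wraps the model in a 'module.' namespace,
    # so saved keys look like 'module.encoder.weight' instead of 'encoder.weight'.
    return any('module.' in key for key in state_dict.keys())
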
@@ -54,6 +56,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
     from PyTorch.Recommendation.NCF import neumf as ncf
     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]
     config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
               'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs':[0, 0, 0, 0], 'dropout': 0.5}
@@ -63,8 +66,10 @@ def nvidia_ncf(pretrained=True, **kwargs):
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
-        ckpt_file = "ncf_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         if checkpoint_from_distributed(ckpt):
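
The same download-or-reuse logic is repeated in all three entrypoints: derive the local filename from the checkpoint URL, download only when the file is missing or a reload is forced. A minimal sketch of how it could be factored into one helper (the name _cached_checkpoint is hypothetical, not part of this commit):

import os
import sys
import urllib.request

def _cached_checkpoint(checkpoint_url, force_reload=False):
    # Reuse a previously downloaded checkpoint unless the caller
    # explicitly asks for a fresh copy with force_reload=True.
    ckpt_file = os.path.basename(checkpoint_url)
    if not os.path.exists(ckpt_file) or force_reload:
        sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint_url))
        urllib.request.urlretrieve(checkpoint_url, ckpt_file)
    return ckpt_file
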
@@ -117,14 +122,17 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
     from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float
     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]
     if pretrained:
         if fp16:
             checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
-        ckpt_file = "tacotron2_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
@@ -172,14 +180,17 @@ def nvidia_waveglow(pretrained=True, **kwargs):
     from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float
     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]
     if pretrained:
         if fp16:
             checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
-        ckpt_file = "waveglow_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
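
With caching in place, only the first load pays the download cost; later calls find the checkpoint file in the current working directory and skip urllib.request.urlretrieve. A hypothetical usage sketch (the repo reference is assumed; these entrypoints are presumably exposed through a hubconf.py):

import torch

# First call downloads the fp16 checkpoint into the working directory.
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples',
                          'nvidia_waveglow', model_math='fp16')

# Second call finds the cached file and loads it without re-downloading.
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples',
                          'nvidia_waveglow', model_math='fp16')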