Merge pull request #150 from ailzhang/torchhub_fix

Fix urllib error on colab hub example
This commit is contained in:
nv-kkudrynski 2019-08-14 13:40:46 +02:00 committed by GitHub
commit 0cabbfb67c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -34,6 +34,15 @@ def unwrap_distributed(state_dict):
new_state_dict[new_key] = value
return new_state_dict
def _download_checkpoint(checkpoint, force_reload):
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
if not os.path.exists(model_dir):
os.makedirs(model_dir)
ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
return ckpt_file
dependencies = ['torch']
@ -66,10 +75,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
else:
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
if checkpoint_from_distributed(ckpt):
@ -130,10 +136,7 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
else:
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
if checkpoint_from_distributed(state_dict):
@ -190,10 +193,7 @@ def nvidia_waveglow(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
else:
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
if checkpoint_from_distributed(state_dict):
@ -360,10 +360,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225'
else:
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
ckpt = ckpt['model']
if checkpoint_from_distributed(ckpt):