Merge pull request #150 from ailzhang/torchhub_fix
Fix urllib error on colab hub example
This commit is contained in:
commit
0cabbfb67c
29
hubconf.py
29
hubconf.py
|
@ -34,6 +34,15 @@ def unwrap_distributed(state_dict):
|
|||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
def _download_checkpoint(checkpoint, force_reload):
    """Download a checkpoint URL into the torch hub checkpoint cache.

    The file is stored under ``<torch home>/checkpoints`` using the URL's
    basename. An already-downloaded file is reused unless ``force_reload``
    is true, in which case it is fetched again.

    Args:
        checkpoint: URL of the checkpoint to download.
        force_reload: if true, re-download even when a cached copy exists.

    Returns:
        Local filesystem path of the (possibly cached) checkpoint file.
    """
    model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
    # exist_ok=True avoids the check-then-create race when several
    # processes (or hub calls) initialize the cache directory at once.
    os.makedirs(model_dir, exist_ok=True)
    ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
    if not os.path.exists(ckpt_file) or force_reload:
        sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
        urllib.request.urlretrieve(checkpoint, ckpt_file)
    return ckpt_file
|
||||
|
||||
# torch.hub convention: packages that must be installed to load the
# models declared in this hubconf file.
dependencies = ['torch']
|
||||
|
||||
|
@ -66,10 +75,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
|
|||
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
|
||||
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
|
||||
if checkpoint_from_distributed(ckpt):
|
||||
|
@ -130,10 +136,7 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
|
|||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
|
||||
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
if checkpoint_from_distributed(state_dict):
|
||||
|
@ -190,10 +193,7 @@ def nvidia_waveglow(pretrained=True, **kwargs):
|
|||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
|
||||
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
if checkpoint_from_distributed(state_dict):
|
||||
|
@ -360,10 +360,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
|
|||
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225'
|
||||
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
|
||||
if not os.path.exists(ckpt_file) or force_reload:
|
||||
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
|
||||
urllib.request.urlretrieve(checkpoint, ckpt_file)
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
ckpt = ckpt['model']
|
||||
if checkpoint_from_distributed(ckpt):
|
||||
|
|
Loading…
Reference in a new issue