Fix urllib error on Colab

This commit is contained in:
Ailing Zhang 2019-08-08 18:57:39 -07:00
parent 9e28f35158
commit b1d9921414

View file

@ -34,6 +34,15 @@ def unwrap_distributed(state_dict):
new_state_dict[new_key] = value
return new_state_dict
def _download_checkpoint(checkpoint, force_reload):
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
if not os.path.exists(model_dir):
os.makedirs(model_dir)
ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
return ckpt_file
dependencies = ['torch']
@ -66,10 +75,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
else:
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
if checkpoint_from_distributed(ckpt):
@ -130,10 +136,7 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
else:
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
if checkpoint_from_distributed(state_dict):
@ -190,10 +193,7 @@ def nvidia_waveglow(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
else:
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
if checkpoint_from_distributed(state_dict):
@ -360,10 +360,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225'
else:
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225'
ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
ckpt = ckpt['model']
if checkpoint_from_distributed(ckpt):