Update hubconf.py with valid NGC checkpoints

This commit is contained in:
nv-kkudrynski 2020-11-02 16:26:38 +01:00 committed by GitHub
parent dfe30b6baf
commit b7596093f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -136,9 +136,9 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
if pretrained:
if fp16:
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_amp/versions/19.09.0/files/nvidia_tacotron2pyt_fp16_20190427'
else:
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_fp32/versions/19.09.0/files/nvidia_tacotron2pyt_fp32_20190427'
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
@ -193,9 +193,9 @@ def nvidia_waveglow(pretrained=True, **kwargs):
if pretrained:
if fp16:
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp/versions/19.09.0/files/nvidia_waveglowpyt_fp16_20190427'
else:
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_fp32/versions/19.09.0/files/nvidia_waveglowpyt_fp32_20190427'
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)
state_dict = ckpt['state_dict']
@ -363,10 +363,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
m = batchnorm_to_float(m)
if pretrained:
if fp16:
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225'
else:
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225'
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssd_pyt_ckpt_amp/versions/19.09.0/files/nvidia_ssdpyt_fp16_190826.pt'
# ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=True, check_hash=False)
ckpt_file = _download_checkpoint(checkpoint, force_reload)
ckpt = torch.load(ckpt_file)