Update hubconf.py with valid NGC checkpoints
This commit is contained in:
parent
dfe30b6baf
commit
b7596093f0
13
hubconf.py
13
hubconf.py
|
@ -136,9 +136,9 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
|
|||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_amp/versions/19.09.0/files/nvidia_tacotron2pyt_fp16_20190427'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_fp32/versions/19.09.0/files/nvidia_tacotron2pyt_fp32_20190427'
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
|
@ -193,9 +193,9 @@ def nvidia_waveglow(pretrained=True, **kwargs):
|
|||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp/versions/19.09.0/files/nvidia_waveglowpyt_fp16_20190427'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_fp32/versions/19.09.0/files/nvidia_waveglowpyt_fp32_20190427'
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
state_dict = ckpt['state_dict']
|
||||
|
@ -363,10 +363,7 @@ def nvidia_ssd(pretrained=True, **kwargs):
|
|||
m = batchnorm_to_float(m)
|
||||
|
||||
if pretrained:
|
||||
if fp16:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225'
|
||||
else:
|
||||
checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225'
|
||||
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssd_pyt_ckpt_amp/versions/19.09.0/files/nvidia_ssdpyt_fp16_190826.pt'
|
||||
# ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=True, check_hash=False)
|
||||
ckpt_file = _download_checkpoint(checkpoint, force_reload)
|
||||
ckpt = torch.load(ckpt_file)
|
||||
|
|
Loading…
Reference in a new issue