Compare commits

...

15 commits

Author SHA1 Message Date
nv-kkudrynski b7596093f0
Update hubconf.py with valid NGC checkpoints 2020-11-02 16:26:38 +01:00
nv-kkudrynski dfe30b6baf
Temporary fallback to backup checkpoints 2020-10-30 22:18:37 +01:00
PrzemekS 3e961c3a14
Merge pull request #521 from mwawrzos/torchhub
[SSD][hubconf] contiguous input tensor
2020-05-26 19:31:18 +02:00
Marek Wawrzos ffe84ae70a [SSD][hubconf] contiguous input tensor 2020-05-20 16:00:47 +02:00
GrzegorzKarchNV 9189f07093
Merge pull request #463 from GrzegorzKarchNV/torchhub-taco2-infer-fix
fixed device in tacotron2 inference
2020-04-21 22:08:26 +02:00
gkarch 46758dadf8 fixed device in tacotron2 inference 2020-04-21 22:07:11 +02:00
skierat db408837c2
Merge pull request #418 from bilalsal/patch-1
Fix skimage import issue, for NVIDIA's PyTorch SSD
2020-03-13 12:03:16 +01:00
Bilal Alsallakh 8e3ef240b8
Fix skimage import issue, for NVIDIA's PyTorch SSD
The PyTorch SSD implemented by NVIDIA relies on the nvidia_ssd_processing_utils() routines in this module
(see https://pytorch.org/hub/nvidia_deeplearningexamples_ssd/).
Python 3 has issues accessing skimage's io and transform in the original code (reported in https://github.com/pytorch/hub/issues/77).
2020-03-10 23:29:34 -07:00
nv-kkudrynski e0b6d222b4
Merge pull request #325 from nv-kkudrynski/torchhub
Torchhub
2019-12-02 17:57:42 +01:00
Krzysztof Kudrynski 4c12145030 tacotron2 and waveglow checkpoints hosted on ngc 2019-12-02 17:21:50 +01:00
Krzysztof Kudrynski 5639257b22 switching ssd to ngc hosting 2019-11-21 13:53:33 +01:00
nv-kkudrynski 0cabbfb67c
Merge pull request #150 from ailzhang/torchhub_fix
Fix urllib error on colab hub example
2019-08-14 13:40:46 +02:00
Ailing Zhang b1d9921414 fix urllib error on colab 2019-08-08 18:57:39 -07:00
nvpstr 9e28f35158
Merge pull request #133 from nv-kkudrynski/torchhub
ckpt download location changed to torchhub home
2019-07-31 17:14:04 +02:00
Krzysztof Kudrynski 21d7520e14 ckpt download location changed to torchhub home 2019-07-30 18:37:31 +02:00
2 changed files with 32 additions and 30 deletions

View file

@@ -530,8 +530,8 @@ class Decoder(nn.Module):
         self.initialize_decoder_states(memory, mask=None)

-        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).cuda()
-        not_finished = torch.ones([memory.size(0)], dtype=torch.int32).cuda()
+        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).to(memory.device)
+        not_finished = torch.ones([memory.size(0)], dtype=torch.int32).to(memory.device)
         mel_outputs, gate_outputs, alignments = [], [], []

         while True:
             decoder_input = self.prenet(decoder_input)
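A minimal sketch of what this change buys (assuming a toy memory tensor rather than the real Tacotron2 decoder state, so the names are illustrative only): allocating with .cuda() fails on CPU-only machines, while .to(memory.device) simply follows whatever device the encoder output already lives on.

import torch

memory = torch.randn(2, 10, 512)  # hypothetical encoder output, here on CPU
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).to(memory.device)
not_finished = torch.ones([memory.size(0)], dtype=torch.int32).to(memory.device)
print(mel_lengths.device)  # cpu; would print cuda:0 if memory were on a GPU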

View file

@@ -3,6 +3,7 @@ import torch
 import os
 import sys
+

 # from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
 def checkpoint_from_distributed(state_dict):
     """
@@ -35,6 +36,17 @@ def unwrap_distributed(state_dict):
     return new_state_dict


+def _download_checkpoint(checkpoint, force_reload):
+    model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
+    if not os.path.exists(ckpt_file) or force_reload:
+        sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+        urllib.request.urlretrieve(checkpoint, ckpt_file)
+    return ckpt_file
+
+
 dependencies = ['torch']


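As a rough illustration of the "ckpt download location changed to torchhub home" commit, the helper above caches checkpoints under the Torch Hub directory instead of the current working directory (assuming torch.hub._get_torch_home() resolves as in stock PyTorch, typically ~/.cache/torch):

import os
import torch

checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'  # FP32 NCF URL from this diff
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
print(os.path.join(model_dir, os.path.basename(checkpoint)))
# e.g. /home/user/.cache/torch/checkpoints/joc-ncf-fp32-pyt-20190225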
@@ -66,10 +78,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
-        ckpt_file = os.path.basename(checkpoint)
-        if not os.path.exists(ckpt_file) or force_reload:
-            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
-            urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = _download_checkpoint(checkpoint, force_reload)
         ckpt = torch.load(ckpt_file)

         if checkpoint_from_distributed(ckpt):
@@ -127,13 +136,10 @@ def nvidia_tacotron2(pretrained=True, **kwargs):

     if pretrained:
         if fp16:
-            checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
+            checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_amp/versions/19.09.0/files/nvidia_tacotron2pyt_fp16_20190427'
         else:
-            checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
-        ckpt_file = os.path.basename(checkpoint)
-        if not os.path.exists(ckpt_file) or force_reload:
-            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
-            urllib.request.urlretrieve(checkpoint, ckpt_file)
+            checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_fp32/versions/19.09.0/files/nvidia_tacotron2pyt_fp32_20190427'
+        ckpt_file = _download_checkpoint(checkpoint, force_reload)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
@@ -187,13 +193,10 @@ def nvidia_waveglow(pretrained=True, **kwargs):

     if pretrained:
         if fp16:
-            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
+            checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp/versions/19.09.0/files/nvidia_waveglowpyt_fp16_20190427'
         else:
-            checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
-        ckpt_file = os.path.basename(checkpoint)
-        if not os.path.exists(ckpt_file) or force_reload:
-            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
-            urllib.request.urlretrieve(checkpoint, ckpt_file)
+            checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_fp32/versions/19.09.0/files/nvidia_waveglowpyt_fp32_20190427'
+        ckpt_file = _download_checkpoint(checkpoint, force_reload)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
@@ -225,13 +228,15 @@ def nvidia_waveglow(pretrained=True, **kwargs):
 def nvidia_ssd_processing_utils():
     import numpy as np
     import skimage
+    from skimage import io, transform
+
     from PyTorch.Detection.SSD.src.utils import dboxes300_coco, Encoder

     class Processing:
         @staticmethod
         def load_image(image_path):
             """Code from Loading_Pretrained_Models.ipynb - a Caffe2 tutorial"""
-            img = skimage.img_as_float(skimage.io.imread(image_path))
+            img = skimage.img_as_float(io.imread(image_path))
             if len(img.shape) == 2:
                 img = np.array([img, img, img]).swapaxes(0, 2)
             return img
@@ -243,13 +248,13 @@ def nvidia_ssd_processing_utils():
             if (aspect > 1):
                 # landscape orientation - wide image
                 res = int(aspect * input_height)
-                imgScaled = skimage.transform.resize(img, (input_width, res))
+                imgScaled = transform.resize(img, (input_width, res))
             if (aspect < 1):
                 # portrait orientation - tall image
                 res = int(input_width / aspect)
-                imgScaled = skimage.transform.resize(img, (res, input_height))
+                imgScaled = transform.resize(img, (res, input_height))
             if (aspect == 1):
-                imgScaled = skimage.transform.resize(img, (input_width, input_height))
+                imgScaled = transform.resize(img, (input_width, input_height))
             return imgScaled

         @staticmethod
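The two hunks above implement the fix described in pytorch/hub#77: with older scikit-image releases, a bare "import skimage" may not expose the io and transform submodules, so skimage.io.imread / skimage.transform.resize can raise AttributeError at runtime. A small hedged sketch of the working pattern (the image path is a hypothetical placeholder):

import skimage                     # top-level package alone may not expose skimage.io / skimage.transform
from skimage import io, transform  # explicit submodule imports, as in the fix

img = skimage.img_as_float(io.imread('example.jpg'))  # 'example.jpg' is a placeholder path
scaled = transform.resize(img, (300, 300))            # 300x300 matches the SSD300 input size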
@@ -270,6 +275,7 @@ def nvidia_ssd_processing_utils():
             NHWC = np.array(inputs)
             NCHW = np.swapaxes(np.swapaxes(NHWC, 1, 3), 2, 3)
             tensor = torch.from_numpy(NCHW)
+            tensor = tensor.contiguous()
             tensor = tensor.cuda()
             tensor = tensor.float()
             if fp16:
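Why the added .contiguous() call matters (illustrative, using the same NHWC-to-NCHW shuffle as in the hunk above): np.swapaxes returns a strided view, and torch.from_numpy preserves that layout, so the resulting tensor is not contiguous until it is copied.

import numpy as np
import torch

nhwc = np.zeros((1, 300, 300, 3), dtype=np.float32)  # batch of one 300x300 RGB image
nchw = np.swapaxes(np.swapaxes(nhwc, 1, 3), 2, 3)     # same axis swap as above
t = torch.from_numpy(nchw)
print(t.is_contiguous())               # False
print(t.contiguous().is_contiguous())  # True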
@@ -326,6 +332,7 @@ def nvidia_ssd_processing_utils():
     return Processing()


+
 def nvidia_ssd(pretrained=True, **kwargs):
     """Constructs an SSD300 model.
     For detailed information on model input and output, training recipies, inference and performance
@@ -356,14 +363,9 @@
         m = batchnorm_to_float(m)

     if pretrained:
-        if fp16:
-            checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225'
-        else:
-            checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225'
-        ckpt_file = os.path.basename(checkpoint)
-        if not os.path.exists(ckpt_file) or force_reload:
-            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
-            urllib.request.urlretrieve(checkpoint, ckpt_file)
+        checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/ssd_pyt_ckpt_amp/versions/19.09.0/files/nvidia_ssdpyt_fp16_190826.pt'
+        # ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=True, check_hash=False)
+        ckpt_file = _download_checkpoint(checkpoint, force_reload)
         ckpt = torch.load(ckpt_file)
         ckpt = ckpt['model']
         if checkpoint_from_distributed(ckpt):
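For context, these hubconf entrypoints are consumed through torch.hub; a hedged usage sketch (repo and entrypoint names follow NVIDIA's published Torch Hub pages, and checkpoints land in the Torch Hub checkpoints directory via _download_checkpoint above):

import torch

ssd = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', pretrained=True)
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2')
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow')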