[Jasper/PyT] Fix pyprof import

Adrian Lancucki 2021-06-14 15:35:58 +02:00
parent 15af494a8e
commit f41d86db2a
4 changed files with 28 additions and 13 deletions

View file

@@ -24,7 +24,7 @@ COPY requirements.txt .
 RUN conda install -y pyyaml==5.4.1
 RUN pip install --disable-pip-version-check -U -r requirements.txt
-RUN pip install --force-reinstall --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly nvidia-dali-nightly-cuda110==0.28.0.dev20201026
+RUN pip install --force-reinstall --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110==1.2.0
 
 # Copy rest of files
 COPY . .
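
The Dockerfile change swaps the October 2020 DALI nightly wheel for the stable 1.2.0 release; the DALI 1.x API is what the renames in the next file track. A quick sanity check for the built image (a sketch, not part of the commit):

import nvidia.dali
# Expect the stable release pinned in the Dockerfile above.
assert nvidia.dali.__version__ == "1.2.0", nvidia.dali.__version__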

View file

@@ -14,6 +14,7 @@
 import nvidia.dali
 import nvidia.dali.ops as ops
+import nvidia.dali.ops.random as random
 import nvidia.dali.types as types
 import multiprocessing
 import numpy as np
@@ -103,14 +104,14 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
             self.speed_perturbation_coeffs = ops.ExternalSource(device="cpu", cycle=True,
                                                                 source=self._discrete_resample_coeffs_generator)
         elif resample_range is not None:
-            self.speed_perturbation_coeffs = ops.Uniform(device="cpu", range=resample_range)
+            self.speed_perturbation_coeffs = random.Uniform(device="cpu", range=resample_range)
         else:
             self.speed_perturbation_coeffs = None
 
         self.decode = ops.AudioDecoder(device="cpu", sample_rate=self.sample_rate if resample_range is None else None,
                                        dtype=types.FLOAT, downmix=True)
 
-        self.normal_distribution = ops.NormalDistribution(device=preprocessing_device)
+        self.normal_distribution = random.Normal(device=preprocessing_device)
 
         self.preemph = ops.PreemphasisFilter(device=preprocessing_device, preemph_coeff=preemph_coeff)
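
DALI 1.0 moved the random-number operators into their own submodule: ops.Uniform became ops.random.Uniform and ops.NormalDistribution became ops.random.Normal, which is what the new import and the two replacements above reflect. A minimal sketch of the new spelling (the range values are illustrative, not Jasper's):

import nvidia.dali.ops.random as random

# DALI >= 1.0: RNG operators live under ops.random.
speed_coeffs = random.Uniform(device="cpu", range=[0.85, 1.15])  # was ops.Uniform
noise = random.Normal(device="cpu")                              # was ops.NormalDistribution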
@@ -325,7 +326,7 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
         Generate resample coeffs from discrete set
         """
         yield np.random.choice([self.resample_range[0], 1.0, self.resample_range[1]],
-                               size=self.batch_size).astype('float32')
+                               size=self.max_batch_size).astype('float32')
 
     def _cutouts_generator(self):
         """
@@ -339,7 +340,7 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
             """
             return map(list, zip(*tuples))
 
-        [anchors, shapes] = tuples2list([self._generate_cutouts() for _ in range(self.batch_size)])
+        [anchors, shapes] = tuples2list([self._generate_cutouts() for _ in range(self.max_batch_size)])
         yield np.array(anchors, dtype=np.float32), np.array(shapes, dtype=np.float32)
 
     def define_graph(self):
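
Both generator fixes track the same DALI 1.0 rename: the pipeline's batch_size attribute is deprecated in favor of max_batch_size. A minimal sketch of the pattern (hypothetical pipeline parameters, not Jasper's):

from nvidia.dali.pipeline import Pipeline

pipe = Pipeline(batch_size=16, num_threads=4, device_id=0)
# DALI >= 1.0: read the configured batch size via max_batch_size;
# the old .batch_size property is deprecated.
print(pipe.max_batch_size)  # 16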

View file

@@ -175,18 +175,28 @@ def normalize_batch(x, seq_len, normalize_type: str):
 
 @torch.jit.script
-def splice_frames(x, frame_splicing: int):
-    """ Stacks frames together across feature dim
+def stack_subsample_frames(x, x_lens, stacking: int = 1, subsampling: int = 1):
+    """ Stacks frames together across feature dim, and then subsamples
 
     input is batch_size, feature_dim, num_frames
-    output is batch_size, feature_dim*frame_splicing, num_frames
+    output is batch_size, feature_dim * stacking, num_frames / subsampling
     """
     seq = [x]
-    # TORCHSCRIPT: JIT doesnt like range(start, stop)
-    for n in range(frame_splicing - 1):
-        seq.append(torch.cat([x[:, :, :n + 1], x[:, :, n + 1:]], dim=2))
-    return torch.cat(seq, dim=1)
+    for n in range(1, stacking):
+        tmp = torch.zeros_like(x)
+        tmp[:, :, :-n] = x[:, :, n:]
+        seq.append(tmp)
+    x = torch.cat(seq, dim=1)[:, :, ::subsampling]
+
+    if subsampling > 1:
+        x_lens = torch.ceil(x_lens.float() / subsampling).int()
+
+        if x.size(2) > x_lens.max().item():
+            assert abs(x.size(2) - x_lens.max().item()) <= 1
+            x = x[:, :, :x_lens.max().item()]
+
+    return x, x_lens
 
 
 class FilterbankFeatures(BaseFeatures):
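
The old splice_frames duplicated frames without touching the time axis; the replacement stacks `stacking` time-shifted copies along the feature dim, keeps every `subsampling`-th frame, and rescales the sequence lengths to match. A small shape check (illustrative values, assuming stack_subsample_frames above is importable):

import torch

x = torch.randn(2, 64, 13)                         # batch, features, frames
x_lens = torch.tensor([13, 9], dtype=torch.int32)
y, y_lens = stack_subsample_frames(x, x_lens, stacking=3, subsampling=3)
print(y.shape)   # torch.Size([2, 192, 5]): 64 * 3 features, ceil(13 / 3) frames
print(y_lens)    # tensor([5, 3], dtype=torch.int32)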

View file

@@ -18,7 +18,11 @@ import os
 import random
 import time
 
-import pyprof
+try:
+    import nvidia_dlprof_pytorch_nvtx as pyprof
+except ModuleNotFoundError:
+    import pyprof
+
 import torch
 import numpy as np
 import torch.cuda.profiler as profiler
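
The fallback prefers the DLProf NVTX hook (nvidia_dlprof_pytorch_nvtx), which supersedes the standalone pyprof package in newer NGC PyTorch containers, while keeping pyprof importable on older images; both expose an init() entry point. A sketch of the common usage pattern (not Jasper's exact training loop):

try:
    import nvidia_dlprof_pytorch_nvtx as pyprof
except ModuleNotFoundError:
    import pyprof
import torch

pyprof.init()  # patch torch ops to emit NVTX ranges for the profiler
with torch.autograd.profiler.emit_nvtx():
    pass  # profiled training / inference steps go here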