# DeepLearningExamples/PyTorch/SpeechRecognition/QuartzNet/common/dali/pipeline.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import math
import multiprocessing
import numpy as np
import nvidia.dali as dali
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import torch
import torch.distributed as dist


def _interleave_lists(*lists):
"""
[*, **, ***], [1, 2, 3], [a, b, c] -> [*, 1, a, **, 2, b, ***, 3, c]
Returns:
iterator over interleaved list
"""
    assert all(len(lists[0]) == len(test_l) for test_l in lists), \
        "All lists must have the same length"
return itertools.chain(*zip(*lists))


def _generate_cutouts(mask_params, nfeatures):
    """
    Generate anchors and shapes of the cutout regions for a single sample;
    the caller assembles these into a batch.

    The output is meant to be passed to DALI's Erase operator, interleaved as
        anchors = [f0, t0, f1, t1, ...]
        shapes  = [f0w, t0h, f1w, t1h, ...]
    Anchors are normalized coordinates in [0, 1); shapes are absolute sizes.
    Frequency masks span the whole time axis (via MAX_TIME_DIMENSION), time
    masks span all `nfeatures` mel bands, and "both" masks are bounded in
    both dimensions.

    Returns:
        (anchors, shapes) lists for one sample.
    """
    # Upper bound on the spectrogram time extent (20 s of 16 kHz audio);
    # Erase clips each region to the actual tensor shape.
    MAX_TIME_DIMENSION = 20 * 16000
freq_anchors = np.random.random(mask_params['freq_num_regions'])
time_anchors = np.random.random(mask_params['time_num_regions'])
both_anchors_freq = np.random.random(mask_params['both_num_regions'])
both_anchors_time = np.random.random(mask_params['both_num_regions'])
anchors = []
for anch in freq_anchors:
anchors.extend([anch, 0])
for anch in time_anchors:
anchors.extend([0, anch])
for t, f in zip(both_anchors_time, both_anchors_freq):
anchors.extend([f, t])
shapes = []
shapes.extend(
_interleave_lists(
np.random.randint(mask_params['freq_min'],
mask_params['freq_max'] + 1,
mask_params['freq_num_regions']),
            # XXX: Ideally the actual time extent of the spectrogram would be
            # passed here, but DALI argument inputs cannot come from the GPU.
            # Instead we pass an upper bound and rely on Erase (the masking
            # operator) to clip each region to the real shape.
[int(MAX_TIME_DIMENSION)] * mask_params['freq_num_regions']
)
)
shapes.extend(
_interleave_lists(
[nfeatures] * mask_params['time_num_regions'],
np.random.randint(mask_params['time_min'],
mask_params['time_max'] + 1,
mask_params['time_num_regions'])
)
)
shapes.extend(
_interleave_lists(
np.random.randint(mask_params['both_min_freq'],
mask_params['both_max_freq'] + 1,
mask_params['both_num_regions']),
np.random.randint(mask_params['both_min_time'],
mask_params['both_max_time'] + 1,
mask_params['both_num_regions'])
)
)
return anchors, shapes


def _tuples2list(tuples: list):
"""
[(a, b), (c, d)] -> [[a, c], [b, d]]
"""
return map(list, zip(*tuples))


def _dali_init_log(args: dict):
    if not dist.is_initialized() or dist.get_rank() == 0:
        max_len = max(len(ii) for ii in args.keys())
        fmt_string = '\t%' + str(max_len) + 's : %s'
        print('Initializing DALI with parameters:')
        for key_pair in sorted(args.items()):
            print(fmt_string % key_pair)


@dali.pipeline_def
def dali_asr_pipeline(train_pipeline, # True if training, False if validation
file_root,
file_list,
sample_rate,
silence_threshold,
resample_range,
discrete_resample_range,
window_size,
window_stride,
nfeatures,
nfft,
frame_splicing_factor,
dither_coeff,
pad_align,
preemph_coeff,
do_spectrogram_masking=False,
cutouts_generator=None,
shard_id=0,
n_shards=1,
preprocessing_device="gpu"):
do_remove_silence = silence_threshold is not None
def _div_ceil(dividend, divisor):
return (dividend + (divisor - 1)) // divisor
encoded, label = fn.readers.file(
device="cpu", name="file_reader", file_root=file_root,
file_list=file_list, shard_id=shard_id, num_shards=n_shards,
shuffle_after_epoch=train_pipeline)
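    # Speed perturbation (training only) is applied at decode time: the audio
    # decoder resamples to `coeff * sample_rate`, where `coeff` is drawn either
    # from the discrete set {resample_range[0], 1.0, resample_range[1]} or
    # uniformly from `resample_range`.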
speed_perturbation_coeffs = None
if resample_range is not None:
if discrete_resample_range:
values = [resample_range[0], 1.0, resample_range[1]]
speed_perturbation_coeffs = fn.random.uniform(device="cpu",
values=values)
else:
speed_perturbation_coeffs = fn.random.uniform(device="cpu",
range=resample_range)
if train_pipeline and speed_perturbation_coeffs is not None:
dec_sample_rate_arg = speed_perturbation_coeffs * sample_rate
elif resample_range is None:
dec_sample_rate_arg = sample_rate
else:
dec_sample_rate_arg = None
audio, _ = fn.decoders.audio(encoded, sample_rate=dec_sample_rate_arg,
dtype=types.FLOAT, downmix=True)
if do_remove_silence:
begin, length = fn.nonsilent_region(audio, cutoff_db=silence_threshold)
audio = fn.slice(audio, begin, length, axes=[0])
# Max duration drop is performed at DataLayer stage
if preprocessing_device == "gpu":
audio = audio.gpu()
if dither_coeff != 0.:
        audio = audio + fn.random.normal(
            device=preprocessing_device) * dither_coeff
audio = fn.preemphasis_filter(audio, preemph_coeff=preemph_coeff)
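    # Feature extraction: power spectrogram -> mel filter bank -> log scale
    # (to_decibels with multiplier=ln(10) yields natural-log features).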
spec = fn.spectrogram(audio, nfft=nfft,
window_length=window_size * sample_rate,
window_step=window_stride * sample_rate)
mel_spec = fn.mel_filter_bank(spec, sample_rate=sample_rate,
nfilter=nfeatures, normalize=True)
log_features = fn.to_decibels(mel_spec, multiplier=np.log(10),
reference=1.0, cutoff_db=math.log(1e-20))
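    # Record the feature shape before padding so downstream code can tell
    # real frames from padded ones.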
log_features_len = fn.shapes(log_features)
if frame_splicing_factor != 1:
log_features_len = _div_ceil(log_features_len, frame_splicing_factor)
log_features = fn.normalize(log_features, axes=[1])
log_features = fn.pad(log_features, axes=[1], fill_value=0, align=pad_align)
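    # SpecAugment / Cutout (training only): random mask anchors and shapes are
    # produced on the CPU by `cutouts_generator`, fed in via ExternalSource,
    # and applied with Erase (normalized anchors, absolute shapes).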
if train_pipeline and do_spectrogram_masking:
anchors, shapes = fn.external_source(source=cutouts_generator,
num_outputs=2, cycle=True)
log_features = fn.erase(log_features, anchor=anchors, shape=shapes,
axes=[0, 1], fill_value=0,
normalized_anchor=True)
# When modifying DALI pipeline returns, make sure you update `output_map`
# in DALIGenericIterator invocation
return log_features.gpu(), label.gpu(), log_features_len.gpu()


def make_dali_asr_pipeline(train_pipeline: bool, device_id, batch_size,
file_root: str, file_list: str, config_data: dict,
config_features: dict, device_type: str = "gpu",
do_resampling: bool = True,
num_cpu_threads: int = multiprocessing.cpu_count()):
max_duration = config_data['max_duration']
sample_rate = config_data['sample_rate']
silence_threshold = -60 if config_data['trim_silence'] else None
    # TODO: Take the resampling probability
    #       (config_data['speed_perturbation']['p']) into account
if do_resampling and config_data['speed_perturbation'] is not None:
resample_range = [config_data['speed_perturbation']['min_rate'],
config_data['speed_perturbation']['max_rate']]
discrete_resample_range = config_data['speed_perturbation']['discrete']
else:
resample_range = None
discrete_resample_range = False
window_size = config_features['window_size']
window_stride = config_features['window_stride']
nfeatures = config_features['n_filt']
nfft = config_features['n_fft']
frame_splicing_factor = config_features['frame_splicing']
dither_coeff = config_features['dither']
pad_align = config_features['pad_align']
pad_to_max_duration = config_features['pad_to_max_duration']
assert not pad_to_max_duration, \
"Padding to max duration currently not supported in DALI"
preemph_coeff = .97
config_spec = config_features['spec_augment']
if config_spec is not None:
mask_time_num_regions = config_spec['time_masks']
mask_time_min = config_spec['min_time']
mask_time_max = config_spec['max_time']
mask_freq_num_regions = config_spec['freq_masks']
mask_freq_min = config_spec['min_freq']
mask_freq_max = config_spec['max_freq']
else:
mask_time_num_regions = 0
mask_time_min = 0
mask_time_max = 0
mask_freq_num_regions = 0
mask_freq_min = 0
mask_freq_max = 0
config_cutout = config_features['cutout_augment']
if config_cutout is not None:
mask_both_num_regions = config_cutout['masks']
mask_both_min_time = config_cutout['min_time']
mask_both_max_time = config_cutout['max_time']
mask_both_min_freq = config_cutout['min_freq']
mask_both_max_freq = config_cutout['max_freq']
else:
mask_both_num_regions = 0
mask_both_min_time = 0
mask_both_max_time = 0
mask_both_min_freq = 0
mask_both_max_freq = 0
do_spectrogram_masking = \
mask_time_num_regions > 0 or mask_freq_num_regions > 0 or \
mask_both_num_regions > 0
do_remove_silence = silence_threshold is not None
    # Remove the config dicts from the local scope so they are not dumped
    # by the parameter log below.
    del config_spec
    del config_cutout
    del config_data
    del config_features
_dali_init_log(locals())
mask_params = {
'time_num_regions': mask_time_num_regions,
'time_min': mask_time_min,
'time_max': mask_time_max,
'freq_num_regions': mask_freq_num_regions,
'freq_min': mask_freq_min,
'freq_max': mask_freq_max,
'both_num_regions': mask_both_num_regions,
'both_min_time': mask_both_min_time,
'both_max_time': mask_both_max_time,
'both_min_freq': mask_both_min_freq,
'both_max_freq': mask_both_max_freq,
}
def _cutouts_generator():
"""
Generator, that wraps cutouts creation in order to randomize inputs
and allow passing them to DALI's ExternalSource operator
"""
[anchors, shapes] = _tuples2list(
[_generate_cutouts(mask_params, nfeatures)
for _ in range(batch_size)])
yield (np.array(anchors, dtype=np.float32),
np.array(shapes, dtype=np.float32))
cutouts_gen = _cutouts_generator if do_spectrogram_masking else None
    if dist.is_initialized():
        shard_id = dist.get_rank()
        n_shards = dist.get_world_size()
else:
shard_id = 0
n_shards = 1
preprocessing_device = device_type.lower()
assert preprocessing_device == "cpu" or preprocessing_device == "gpu", \
"Incorrect preprocessing device. Please choose either 'cpu' or 'gpu'"
pipe = dali_asr_pipeline(
train_pipeline=train_pipeline,
file_root=file_root,
file_list=file_list,
sample_rate=sample_rate,
silence_threshold=silence_threshold,
resample_range=resample_range,
discrete_resample_range=discrete_resample_range,
window_size=window_size,
window_stride=window_stride,
nfeatures=nfeatures,
nfft=nfft,
frame_splicing_factor=frame_splicing_factor,
dither_coeff=dither_coeff,
pad_align=pad_align,
preemph_coeff=preemph_coeff,
do_spectrogram_masking=do_spectrogram_masking,
cutouts_generator=cutouts_gen,
shard_id=shard_id,
n_shards=n_shards,
preprocessing_device=preprocessing_device,
batch_size=batch_size,
num_threads=num_cpu_threads,
device_id=device_id
)
return pipe
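

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training code).
# It shows how the pipeline might be wrapped with DALIGenericIterator.
# The dataset paths and the `example_config_*` dicts below are assumptions;
# adjust them to your setup. `file_list` must be a DALI file list (one
# "<path> <label>" pair per line), and the `output_map` labels are arbitrary
# keys naming the three pipeline outputs (features, labels, feature shapes).
if __name__ == "__main__":
    from nvidia.dali.plugin.pytorch import DALIGenericIterator

    example_config_data = {                    # hypothetical values
        'max_duration': 16.7,
        'sample_rate': 16000,
        'trim_silence': True,
        'speed_perturbation': None,
    }
    example_config_features = {                # hypothetical values
        'window_size': 0.02,
        'window_stride': 0.01,
        'n_filt': 64,
        'n_fft': 512,
        'frame_splicing': 1,
        'dither': 1e-5,
        'pad_align': 16,
        'pad_to_max_duration': False,
        'spec_augment': None,
        'cutout_augment': None,
    }

    pipe = make_dali_asr_pipeline(
        train_pipeline=True, device_id=0, batch_size=16,
        file_root="/datasets/LibriSpeech",                 # assumed path
        file_list="/datasets/LibriSpeech/train-list.txt",  # assumed path
        config_data=example_config_data,
        config_features=example_config_features)
    pipe.build()

    data_iter = DALIGenericIterator(
        [pipe], output_map=["audio", "label", "audio_shape"],
        reader_name="file_reader", auto_reset=True)
    for batch in data_iter:
        feats = batch[0]["audio"]        # log-mel features, on the GPU
        labels = batch[0]["label"]
        break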