Adding parallel transcribe for ASR models - supports multi-GPU/multi-node (#3017)

* added transcribe_speech_parallel.py.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed style.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* removed comments.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed bug.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed bug.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* added comments inside the script.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed speed_collate_fn for TTS.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed speed_collate_fn for TTS.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed speed_collate_fn for TTS.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed speed_collate_fn for TTS.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* added return_sample_id.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* added return_sample_id.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* merged dataset configs.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* merged dataset configs.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* dropped sample_ids from train and validation.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* dropped sample_ids from train and validation.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* reverted tts patches.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* reverted tts patches.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed the default values in the dataset's config

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed the default values in the dataset's config

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* Fixed the bug for optional outputs.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* Fixed the bug for optional outputs.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* addressed some comments.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* disabled dali support for return_sample_id.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* disabled dali support for return_sample_id.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* converted the config to omegaconf.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* converted the config to omegaconf.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* converted the config to omegaconf.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* moved wer/cer calculation to the end.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* moved wer/cer calculation to the end.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* moved wer/cer calculation to the end.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* moved wer/cer calculation to the end.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* moved wer/cer calculation to the end.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* calculates global wer instead of per sample.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* calculates global wer instead of per sample.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* calculates global wer instead of per sample.

Signed-off-by: Vahid <vnoroozi@nvidia.com>
Vahid Noroozi 2021-11-10 00:37:19 -08:00 committed by GitHub
parent b12ac8ae85
commit cfcf694e30
16 changed files with 379 additions and 64 deletions
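
For reference, a minimal invocation of the new script, mirroring the usage examples in its docstring below (the manifest and output paths here are placeholders):

python transcribe_speech_parallel.py \
    model=stt_en_conformer_ctc_large \
    predict_ds.manifest_filepath=/data/manifest.json \
    predict_ds.batch_size=16 \
    output_path=/tmp/predictions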

View file

@ -28,9 +28,6 @@ model:
log_prediction: true # enables logging sample predictions in the output during training
model_defaults:
se: true
se_context_size: -1
# encoder / decoder / joint values
enc_hidden: ${model.encoder.d_model}
pred_hidden: 640
joint_hidden: 640
@ -107,8 +104,10 @@ model:
# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8
xscaling: true # scales up the inputs by sqrt(d_model)
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

View file

@ -31,9 +31,6 @@ model:
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
model_defaults:
se: true
se_context_size: -1
# encoder / decoder / joint values
enc_hidden: ${model.encoder.d_model}
pred_hidden: 640
joint_hidden: 640
@ -102,8 +99,10 @@ model:
# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8
xscaling: true # scales up the inputs by sqrt(d_model)
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

View file

@ -0,0 +1,189 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# ASR transcribe/inference with multi-GPU/multi-node support for large datasets
# It supports both tarred and non-tarred datasets
# Arguments
# model: path to a nemo/PTL checkpoint file or name of a pretrained model
# predict_ds: config of the dataset/dataloader
# output_path: path to store the predictions
# return_predictions: whether to also return the predictions in memory in addition to writing them into files
# use_cer: whether to report the error rate as CER instead of the default WER
#
# The results of each GPU/worker are written into a file named 'predictions_{rank}.json', and the aggregated results of all workers are written into 'predictions_all.json'
Example for non-tarred datasets:
python transcribe_speech_parallel.py \
model=stt_en_conformer_ctc_large \
predict_ds.manifest_filepath=/dataset/manifest_file.json \
predict_ds.batch_size=16 \
output_path=/tmp/
Example for tarred datasets:
python transcribe_speech_parallel.py \
predict_ds.is_tarred=true \
predict_ds.manifest_filepath=/tarred_dataset/tarred_audio_manifest.json \
predict_ds.tarred_audio_filepaths=/tarred_dataset/audio__OP_0..127_CL_.tar \
...
By default the trainer uses all available GPUs and FP32 precision.
These can be controlled through the trainer config. For example, to run the predictions with AMP on just two GPUs:
python transcribe_speech_parallel.py \
trainer.precision=16 \
trainer.gpus=2 \
...
You may control the dataloader's config by setting the predict_ds:
python transcribe_speech_parallel.py \
predict_ds.num_workers=8 \
predict_ds.min_duration=2.0 \
predict_ds.sample_rate=16000 \
model=stt_en_conformer_ctc_small \
...
"""
import itertools
import json
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional
import pytorch_lightning as ptl
import torch
from omegaconf import OmegaConf
from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
from nemo.core.config import TrainerConfig, hydra_runner
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
@dataclass
class ParallelTranscriptionConfig:
model: Optional[str] = None  # name of a pretrained model or path to a .nemo/.ckpt checkpoint file
predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4)
output_path: Optional[str] = None
# when return_predictions is enabled, the prediction call keeps all predictions in memory and returns them once prediction is done
return_predictions: bool = False
use_cer: bool = False
# decoding strategy for RNNT models
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()
trainer: TrainerConfig = TrainerConfig(gpus=-1, accelerator="ddp")
def match_train_config(predict_ds, train_ds):
# Copies the important configurations from the model's train dataset config
# into predict_ds so that prediction uses the same settings as training.
if train_ds is None:
return
predict_ds.sample_rate = train_ds.get("sample_rate", 16000)
cfg_name_list = [
"int_values",
"use_start_end_token",
"blank_index",
"unk_index",
"normalize",
"parser",
"eos_id",
"bos_id",
"pad_id",
]
if is_dataclass(predict_ds):
predict_ds = OmegaConf.structured(predict_ds)
for cfg_name in cfg_name_list:
if hasattr(train_ds, cfg_name):
setattr(predict_ds, cfg_name, getattr(train_ds, cfg_name))
return predict_ds
@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig)
def main(cfg: ParallelTranscriptionConfig):
if cfg.model.endswith(".nemo"):
logging.info("Attempting to initialize from .nemo file")
model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
elif cfg.model.endswith(".ckpt"):
logging.info("Attempting to initialize from .ckpt file")
model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu")
else:
logging.info(
"Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
)
model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")
trainer = ptl.Trainer(**cfg.trainer)
cfg.predict_ds.return_sample_id = True
cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)
data_loader = model._setup_dataloader_from_config(cfg.predict_ds)
os.makedirs(cfg.output_path, exist_ok=True)
# trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
global_rank = trainer.node_rank * trainer.num_gpus + int(os.environ.get("LOCAL_RANK", 0))
output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file)
trainer.callbacks.extend([predictor_writer])
predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions)
if predictions is not None:
predictions = list(itertools.chain.from_iterable(predictions))
samples_num = predictor_writer.close_output_file()
logging.info(
f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
)
if torch.distributed.is_initialized():
torch.distributed.barrier()
samples_num = 0
pred_text_list = []
text_list = []
if is_global_rank_zero():
output_file = os.path.join(cfg.output_path, f"predictions_all.json")
logging.info(f"Prediction files are being aggregated in {output_file}.")
with open(output_file, 'w') as outf:
for rank in range(trainer.world_size):
input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
with open(input_file, 'r') as inpf:
lines = inpf.readlines()
for line in lines:
item = json.loads(line)
pred_text_list.append(item["pred_text"])
text_list.append(item["text"])
outf.write(json.dumps(item) + "\n")
samples_num += 1
wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
logging.info(
f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
)
logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))
if __name__ == '__main__':
main()
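
Since the script goes through the standard PyTorch Lightning DDP trainer, a multi-node run should only need the usual trainer overrides plus a cluster launcher. A hedged sketch (the trainer.num_nodes override and the srun launcher are assumptions, not shown in this diff):

# e.g. under SLURM, with one task per GPU on each node
srun python transcribe_speech_parallel.py \
    model=stt_en_conformer_ctc_large \
    predict_ds.manifest_filepath=/data/manifest.json \
    trainer.gpus=8 \
    trainer.num_nodes=2 \
    output_path=/results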

View file

@ -50,7 +50,14 @@ def _speech_collate_fn(batch, pad_id):
encoded tokens, and encoded tokens length. This collate func
assumes the signals are 1d torch tensors (i.e. mono audio).
"""
_, audio_lengths, _, tokens_lengths = zip(*batch)
packed_batch = list(zip(*batch))
if len(packed_batch) == 5:
_, audio_lengths, _, tokens_lengths, sample_ids = packed_batch
elif len(packed_batch) == 4:
sample_ids = None
_, audio_lengths, _, tokens_lengths = packed_batch
else:
raise ValueError("Expects 4 or 5 tensors in the batch!")
max_audio_len = 0
has_audio = audio_lengths[0] is not None
if has_audio:
@ -58,7 +65,11 @@ def _speech_collate_fn(batch, pad_id):
max_tokens_len = max(tokens_lengths).item()
audio_signal, tokens = [], []
for sig, sig_len, tokens_i, tokens_i_len in batch:
for b in batch:
if len(b) == 5:
sig, sig_len, tokens_i, tokens_i_len, _ = b
else:
sig, sig_len, tokens_i, tokens_i_len = b
if has_audio:
sig_len = sig_len.item()
if sig_len < max_audio_len:
@ -78,8 +89,11 @@ def _speech_collate_fn(batch, pad_id):
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
return audio_signal, audio_lengths, tokens, tokens_lengths
if sample_ids is None:
return audio_signal, audio_lengths, tokens, tokens_lengths
else:
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids
class ASRManifestProcessor:
@ -115,7 +129,7 @@ class ASRManifestProcessor:
self.parser = parser
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(','),
manifests_files=manifest_filepath,
parser=parser,
min_duration=min_duration,
max_duration=max_duration,
@ -164,6 +178,7 @@ class _AudioTextDataset(Dataset):
normalize: whether to normalize transcript text (default): True
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
return_sample_id (bool): whether to return the sample_id as a part of each sample
"""
@property
@ -175,6 +190,7 @@ class _AudioTextDataset(Dataset):
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
'transcripts': NeuralType(('B', 'T'), LabelsType()),
'transcript_length': NeuralType(tuple('B'), LengthsType()),
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
}
def __init__(
@ -191,6 +207,7 @@ class _AudioTextDataset(Dataset):
bos_id: Optional[int] = None,
eos_id: Optional[int] = None,
pad_id: int = 0,
return_sample_id: bool = False,
):
self.manifest_processor = ASRManifestProcessor(
manifest_filepath=manifest_filepath,
@ -204,6 +221,10 @@ class _AudioTextDataset(Dataset):
)
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
self.trim = trim
self.return_sample_id = return_sample_id
def get_manifest_sample(self, sample_id):
return self.manifest_processor.collection[sample_id]
def __getitem__(self, index):
sample = self.manifest_processor.collection[index]
@ -219,7 +240,10 @@ class _AudioTextDataset(Dataset):
t, tl = self.manifest_processor.process_text(index)
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
if self.return_sample_id:
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
else:
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
return output
@ -258,6 +282,7 @@ class AudioToCharDataset(_AudioTextDataset):
normalize: whether to normalize transcript text (default): True
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
return_sample_id (bool): whether to return the sample_id as a part of each sample
"""
@property
@ -269,6 +294,7 @@ class AudioToCharDataset(_AudioTextDataset):
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
'transcripts': NeuralType(('B', 'T'), LabelsType()),
'transcript_length': NeuralType(tuple('B'), LengthsType()),
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
}
def __init__(
@ -289,6 +315,7 @@ class AudioToCharDataset(_AudioTextDataset):
eos_id: Optional[int] = None,
pad_id: int = 0,
parser: Union[str, Callable] = 'en',
return_sample_id: bool = False,
):
self.labels = labels
@ -309,6 +336,7 @@ class AudioToCharDataset(_AudioTextDataset):
bos_id=bos_id,
eos_id=eos_id,
pad_id=pad_id,
return_sample_id=return_sample_id,
)
@ -806,6 +834,7 @@ class AudioToBPEDataset(_AudioTextDataset):
trim: Whether to trim silence segments
use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
tokens to beginning and ending of speech respectively.
return_sample_id (bool): whether to return the sample_id as a part of each sample
"""
@property
@ -817,6 +846,7 @@ class AudioToBPEDataset(_AudioTextDataset):
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
'transcripts': NeuralType(('B', 'T'), LabelsType()),
'transcript_length': NeuralType(tuple('B'), LengthsType()),
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
}
def __init__(
@ -831,6 +861,7 @@ class AudioToBPEDataset(_AudioTextDataset):
max_utts: int = 0,
trim: bool = False,
use_start_end_token: bool = True,
return_sample_id: bool = False,
):
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
bos_id = tokenizer.bos_id
@ -868,6 +899,7 @@ class AudioToBPEDataset(_AudioTextDataset):
eos_id=eos_id,
pad_id=pad_id,
trim=trim,
return_sample_id=return_sample_id,
)
@ -956,12 +988,13 @@ class _TarredAudioToTextDataset(IterableDataset):
sampled at least once during 1 epoch.
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
return_sample_id (bool): whether to return the sample_id as a part of each sample
"""
def __init__(
self,
audio_tar_filepaths: Union[str, List[str]],
manifest_filepath: Union[str, List[str]],
manifest_filepath: str,
parser: Callable,
sample_rate: int,
int_values: bool = False,
@ -977,6 +1010,7 @@ class _TarredAudioToTextDataset(IterableDataset):
shard_strategy: str = "scatter",
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
):
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath,
@ -992,6 +1026,7 @@ class _TarredAudioToTextDataset(IterableDataset):
self.eos_id = eos_id
self.bos_id = bos_id
self.pad_id = pad_id
self.return_sample_id = return_sample_id
valid_shard_strategies = ['scatter', 'replicate']
if shard_strategy not in valid_shard_strategies:
@ -1118,7 +1153,13 @@ class _TarredAudioToTextDataset(IterableDataset):
t = t + [self.eos_id]
tl += 1
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
if self.return_sample_id:
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), manifest_idx
else:
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
def get_manifest_sample(self, sample_id):
return self.collection[sample_id]
def __iter__(self):
return self._dataset.__iter__()
@ -1208,6 +1249,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
sampled at least once during 1 epoch.
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
return_sample_id (bool): whether to return the sample_id as a part of each sample
"""
def __init__(
@ -1233,6 +1275,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
shard_strategy: str = "scatter",
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
):
self.labels = labels
@ -1258,6 +1301,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
shard_strategy=shard_strategy,
global_rank=global_rank,
world_size=world_size,
return_sample_id=return_sample_id,
)
@ -1332,12 +1376,13 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
sampled at least once during 1 epoch.
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
return_sample_id (bool): whether to return the sample_id as a part of each sample
"""
def __init__(
self,
audio_tar_filepaths: Union[str, List[str]],
manifest_filepath: Union[str, List[str]],
manifest_filepath: str,
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
sample_rate: int,
int_values: bool = False,
@ -1351,6 +1396,7 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
shard_strategy: str = "scatter",
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
):
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
bos_id = tokenizer.bos_id
@ -1393,4 +1439,5 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
shard_strategy=shard_strategy,
global_rank=global_rank,
world_size=world_size,
return_sample_id=return_sample_id,
)
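
To illustrate the change above, a minimal sketch of how return_sample_id flows from a dataset into the collate function (the manifest path and label set are hypothetical, and the dataset's collate_fn is assumed to wrap _speech_collate_fn):

from torch.utils.data import DataLoader
from nemo.collections.asr.data.audio_to_text import AudioToCharDataset

dataset = AudioToCharDataset(
    manifest_filepath="/data/manifest.json",
    labels=[" ", "a", "b", "c"],
    sample_rate=16000,
    return_sample_id=True,  # each item becomes (sig, sig_len, tokens, tokens_len, index)
)
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)

# with return_sample_id=True the collate fn appends a fifth tensor of sample ids
audio, audio_len, tokens, tokens_len, sample_ids = next(iter(loader))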

View file

@ -140,6 +140,7 @@ class _AudioTextDALIDataset(Iterator):
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
"""
def __init__(
@ -162,8 +163,15 @@ class _AudioTextDALIDataset(Iterator):
global_rank: int = 0,
world_size: int = 1,
preprocessor_cfg: DictConfig = None,
return_sample_id: bool = False,
):
self.drop_last = drop_last # used by lr_scheduler
if return_sample_id:
raise ValueError(
"Currently DALI data layers don't support returning the sample_id and return_sample_id can not be enabled."
)
self.return_sample_id = return_sample_id
if not HAVE_DALI:
raise ModuleNotFoundError(
f"{self} requires NVIDIA DALI to be installed. "
@ -519,6 +527,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset):
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
"""
def __init__(
@ -545,6 +554,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset):
global_rank: int = 0,
world_size: int = 1,
preprocessor_cfg: DictConfig = None,
return_sample_id: bool = False,
):
self.labels = labels
@ -571,6 +581,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset):
global_rank=global_rank,
world_size=world_size,
preprocessor_cfg=preprocessor_cfg,
return_sample_id=return_sample_id,
)
@ -607,6 +618,7 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
use_start_end_token (bool): Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and
ending of speech respectively.
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
"""
def __init__(
@ -627,7 +639,9 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
world_size: int = 1,
preprocessor_cfg: DictConfig = None,
use_start_end_token: bool = True,
return_sample_id: bool = False,
):
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
bos_id = tokenizer.bos_id
else:
@ -670,4 +684,5 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
global_rank=global_rank,
world_size=world_size,
preprocessor_cfg=preprocessor_cfg,
return_sample_id=return_sample_id,
)

View file

@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import json
from typing import Any, List, Optional, Union
import torch
from omegaconf import DictConfig, open_dict
from omegaconf.listconfig import ListConfig
from pytorch_lightning.callbacks import BasePredictionWriter
from torch.utils.data import ChainDataset
from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
@ -91,6 +93,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None)
normalize=config.get('normalize_transcripts', False),
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
return_sample_id=config.get('return_sample_id', False),
)
return dataset
@ -120,6 +123,7 @@ def get_bpe_dataset(
max_utts=config.get('max_utts', 0),
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
return_sample_id=config.get('return_sample_id', False),
)
return dataset
@ -182,6 +186,7 @@ def get_tarred_dataset(
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
)
else:
dataset = audio_to_text.TarredAudioToBPEDataset(
@ -200,6 +205,7 @@ def get_tarred_dataset(
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
)
datasets.append(dataset)
@ -251,6 +257,7 @@ def get_dali_char_dataset(
global_rank=global_rank,
world_size=world_size,
preprocessor_cfg=preprocessor_cfg,
return_sample_id=config.get('return_sample_id', False),
)
return dataset
@ -295,10 +302,44 @@ def get_dali_bpe_dataset(
global_rank=global_rank,
world_size=world_size,
preprocessor_cfg=preprocessor_cfg,
return_sample_id=config.get('return_sample_id', False),
)
return dataset
class ASRPredictionWriter(BasePredictionWriter):
def __init__(self, dataset, output_file: str):
super().__init__(write_interval="batch")
self.outf = open(output_file, 'w')
self.dataset = dataset
self.samples_num = 0
def write_on_batch_end(
self,
trainer,
pl_module: 'LightningModule',
prediction: Any,
batch_indices: List[int],
batch: Any,
batch_idx: int,
dataloader_idx: int,
):
for sample_id, transcribed_text in prediction:
item = {}
sample = self.dataset.get_manifest_sample(sample_id)
item["audio_filepath"] = sample.audio_file
item["duration"] = sample.duration
item["text"] = sample.text_raw
item["pred_text"] = transcribed_text
self.outf.write(json.dumps(item) + "\n")
self.samples_num += 1
return
def close_output_file(self):
self.outf.close()
return self.samples_num
def convert_to_config_list(initial_list):
if initial_list is None or initial_list == []:
raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")
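
Each record the new ASRPredictionWriter emits is a manifest-style JSON line carrying the original sample fields plus the predicted text; one output line might look like this (values hypothetical):

{"audio_filepath": "/data/wavs/utt_0001.wav", "duration": 3.42, "text": "hello world", "pred_text": "hello word"}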

View file

@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo.collections.asr.models.configs.asr_models_config import (
ASRDatasetConfig,
EncDecCTCConfig,
EncDecCTCModelConfig,
)
from nemo.collections.asr.models.configs.classification_models_config import (
EncDecClassificationConfig,
EncDecClassificationDatasetConfig,
EncDecClassificationModelConfig,
)
from nemo.collections.asr.models.configs.ctc_models_config import (
EncDecCTCConfig,
EncDecCTCDatasetConfig,
EncDecCTCModelConfig,
)
from nemo.collections.asr.models.configs.matchboxnet_config import (
EncDecClassificationModelConfigBuilder,
MatchboxNetModelConfig,

View file

@ -27,15 +27,15 @@ from nemo.core.config import modelPT as model_cfg
@dataclass
class EncDecCTCDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
manifest_filepath: Optional[str] = None
class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
manifest_filepath: Optional[Any] = None
sample_rate: int = MISSING
labels: List[str] = MISSING
trim_silence: bool = False
# Tarred dataset support
is_tarred: bool = False
tarred_audio_filepaths: Optional[str] = None
tarred_audio_filepaths: Optional[Any] = None
tarred_shard_strategy: str = "scatter"
shuffle_n: int = 0
@ -54,6 +54,7 @@ class EncDecCTCDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
bos_id: Optional[int] = None
pad_id: int = 0
use_start_end_token: bool = False
return_sample_id: Optional[bool] = False
@dataclass
@ -66,9 +67,9 @@ class EncDecCTCConfig(model_cfg.ModelConfig):
labels: List[str] = MISSING
# Dataset configs
train_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=True)
validation_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False)
test_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False)
train_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=True)
validation_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False)
test_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False)
# Optimizer / Scheduler config
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())

View file

@ -17,7 +17,7 @@ from typing import Any, Callable, List, Optional
from omegaconf import MISSING
from nemo.collections.asr.models.configs import ctc_models_config as ctc_cfg
from nemo.collections.asr.models.configs import asr_models_config as ctc_cfg
from nemo.collections.asr.modules.audio_preprocessing import (
AudioToMelSpectrogramPreprocessorConfig,
SpectrogramAugmentationConfig,
@ -174,13 +174,11 @@ class JasperModelConfig(ctc_cfg.EncDecCTCConfig):
labels: List[str] = MISSING
# Dataset configs
train_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(
train_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
manifest_filepath=None, shuffle=True, trim_silence=True
)
validation_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(
manifest_filepath=None, shuffle=False
)
test_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False)
validation_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False)
test_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False)
# Optimizer / Scheduler config
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
@ -262,13 +260,13 @@ class EncDecCTCModelConfigBuilder(model_cfg.ModelConfigBuilder):
# Note: Autocomplete for users wont work without these overrides
# But practically it is not needed since python will infer at runtime
# def set_train_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None):
# def set_train_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None):
# super().set_train_ds(cfg)
#
# def set_validation_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None):
# def set_validation_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None):
# super().set_validation_ds(cfg)
#
# def set_test_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None):
# def set_test_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None):
# super().set_test_ds(cfg)
def _finalize_cfg(self):

View file

@ -506,6 +506,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
"input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
"processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
"processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
"sample_id": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@property
@ -596,6 +597,22 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
return {'loss': loss_value, 'log': tensorboard_logs}
def predict_step(self, batch, batch_idx, dataloader_idx=0):
signal, signal_len, transcript, transcript_len, sample_id = batch
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
log_probs, encoded_len, predictions = self.forward(
processed_signal=signal, processed_signal_length=signal_len
)
else:
log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)
transcribed_texts = self._wer.ctc_decoder_predictions_tensor(
predictions=predictions, predictions_len=encoded_len, return_hypotheses=False,
)
sample_id = sample_id.cpu().detach().numpy()
return list(zip(sample_id, transcribed_texts))
def validation_step(self, batch, batch_idx, dataloader_idx=0):
signal, signal_len, transcript, transcript_len = batch
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
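
The new predict_step here (and the RNNT variant below) returns (sample_id, transcription) pairs, which is what ASRPredictionWriter and the return_predictions=true path consume; a small sketch of the expected shape (values hypothetical):

# flattened output of trainer.predict(...) in transcribe_speech_parallel.py
predictions = [(0, "hello world"), (1, "the quick brown fox")]
for sample_id, pred_text in predictions:
    print(sample_id, pred_text)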

View file

@ -647,7 +647,6 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
if hasattr(self, '_trainer') and self._trainer is not None:
log_every_n_steps = self._trainer.log_every_n_steps
sample_id = self._trainer.global_step
else:
log_every_n_steps = 1
sample_id = batch_nb
@ -699,6 +698,23 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
return {'loss': loss_value}
def predict_step(self, batch, batch_idx, dataloader_idx=0):
signal, signal_len, transcript, transcript_len, sample_id = batch
# forward() only performs encoder forward
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len)
else:
encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
del signal
best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)
sample_id = sample_id.cpu().detach().numpy()
return list(zip(sample_id, best_hyp_text))
def validation_step(self, batch, batch_idx, dataloader_idx=0):
signal, signal_len, transcript, transcript_len = batch

View file

@ -269,6 +269,7 @@ class Typing(ABC):
# Precompute metadata
metadata = TypecheckMetadata(original_types=output_types, ignore_collections=ignore_collections)
out_types_list = list(metadata.base_types.items())
mandatory_out_types_list = list(metadata.mandatory_types.items())
# First convert all outputs to list/tuple format to check correct number of outputs
if type(out_objects) in (list, tuple):
@ -287,15 +288,15 @@ class Typing(ABC):
pass
# In all other cases, python will wrap multiple outputs into an outer tuple.
# As such, now the number of elements in this outer tuple should exactly match
# the number of output types defined.
elif len(out_types_list) != len(out_container):
# Allow number of output arguments to be <= total output neural types and >= mandatory outputs.
elif len(out_container) > len(out_types_list) or len(out_container) < len(mandatory_out_types_list):
raise TypeError(
"Number of output arguments provided ({}) is not as expected ({}).\n"
"This can be either because insufficient number of output NeuralTypes were provided,"
"Number of output arguments provided ({}) is not as expected. It should be larger than {} and less than {}.\n"
"This can be either because insufficient/extra number of output NeuralTypes were provided,"
"or the provided NeuralTypes {} should enable container support "
"(add '[]' to the NeuralType definition)".format(
len(out_container), len(output_types), output_types
len(out_container), len(out_types_list), len(mandatory_out_types_list), output_types
)
)
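
The relaxed output check above lets a typed method return fewer outputs than it declares, as long as every mandatory (non-optional) output type is still produced. A minimal sketch under that assumption (the toy module and tensors are hypothetical):

import torch

from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import LengthsType, NeuralType

class ToyModule(NeuralModule):
    @property
    def input_types(self):
        return {"lengths": NeuralType(tuple('B'), LengthsType())}

    @property
    def output_types(self):
        return {
            "lengths": NeuralType(tuple('B'), LengthsType()),
            "doubled": NeuralType(tuple('B'), LengthsType()),
            # optional output: may now be omitted from the return value
            "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(self, lengths):
        # returning only the two mandatory outputs passes the relaxed check
        return lengths, lengths * 2

module = ToyModule()
out_lengths, out_doubled = module(lengths=torch.tensor([3, 5]))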

View file

@ -105,5 +105,5 @@ class DatasetConfig:
batch_size: int = 32
drop_last: bool = False
shuffle: bool = False
num_workers: Optional[int] = None
num_workers: Optional[int] = 0
pin_memory: bool = True

View file

@ -237,7 +237,7 @@ class TestEncDecCTCModel:
os.rename(old_tokenizer_dir + '.bkp', old_tokenizer_dir)
@pytest.mark.unit
def test_EncDecCTCDatasetConfig_for_AudioToBPEDataset(self):
def test_ASRDatasetConfig_for_AudioToBPEDataset(self):
# ignore some additional arguments as dataclass is generic
IGNORE_ARGS = [
'is_tarred',
@ -261,10 +261,7 @@ class TestEncDecCTCModel:
REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'}
result = assert_dataclass_signature_match(
audio_to_text.AudioToBPEDataset,
configs.EncDecCTCDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result
@ -273,7 +270,7 @@ class TestEncDecCTCModel:
assert dataclass_subset is None
@pytest.mark.unit
def test_EncDecCTCDatasetConfig_for_TarredAudioToBPEDataset(self):
def test_ASRDatasetConfig_for_TarredAudioToBPEDataset(self):
# ignore some additional arguments as dataclass is generic
IGNORE_ARGS = [
'is_tarred',
@ -303,7 +300,7 @@ class TestEncDecCTCModel:
result = assert_dataclass_signature_match(
audio_to_text.TarredAudioToBPEDataset,
configs.EncDecCTCDatasetConfig,
configs.ASRDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
)

View file

@ -240,7 +240,7 @@ class TestEncDecCTCModel:
# trainer and exp manager should be there
@pytest.mark.unit
def test_EncDecCTCDatasetConfig_for_AudioToCharDataset(self):
def test_ASRDatasetConfig_for_AudioToCharDataset(self):
# ignore some additional arguments as dataclass is generic
IGNORE_ARGS = [
'is_tarred',
@ -259,10 +259,7 @@ class TestEncDecCTCModel:
REMAP_ARGS = {'trim_silence': 'trim'}
result = assert_dataclass_signature_match(
audio_to_text.AudioToCharDataset,
configs.EncDecCTCDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result
@ -271,7 +268,7 @@ class TestEncDecCTCModel:
assert dataclass_subset is None
@pytest.mark.unit
def test_EncDecCTCDatasetConfig_for_TarredAudioToCharDataset(self):
def test_ASRDatasetConfig_for_TarredAudioToCharDataset(self):
# ignore some additional arguments as dataclass is generic
IGNORE_ARGS = [
'is_tarred',
@ -294,7 +291,7 @@ class TestEncDecCTCModel:
result = assert_dataclass_signature_match(
audio_to_text.TarredAudioToCharDataset,
configs.EncDecCTCDatasetConfig,
configs.ASRDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
)

View file

@ -341,7 +341,6 @@ class TestASRDatasets:
num_samples = 10
batch_size = 1
device = 'gpu' if torch.cuda.is_available() else 'cpu'
texts = []
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
@ -394,7 +393,6 @@ class TestASRDatasets:
)
ref_preprocessor = EncDecCTCModel.from_config_dict(preprocessor_cfg)
count = 0
for ref_data, dali_data in zip(ref_dataloader, dali_dataset):
ref_audio, ref_audio_len, _, _ = ref_data
ref_features, ref_features_len = ref_preprocessor(input_signal=ref_audio, length=ref_audio_len)