diff --git a/examples/asr/conf/conformer/conformer_transducer_bpe.yaml b/examples/asr/conf/conformer/conformer_transducer_bpe.yaml index e002b0cce..9f642afdc 100644 --- a/examples/asr/conf/conformer/conformer_transducer_bpe.yaml +++ b/examples/asr/conf/conformer/conformer_transducer_bpe.yaml @@ -28,9 +28,6 @@ model: log_prediction: true # enables logging sample predictions in the output during training model_defaults: - se: true - se_context_size: -1 - # encoder / decoder / joint values enc_hidden: ${model.encoder.d_model} pred_hidden: 640 joint_hidden: 640 @@ -107,8 +104,10 @@ model: # Multi-headed Attention Module's params self_attention_model: rel_pos # rel_pos or abs_pos - n_heads: 8 - xscaling: true # scales up the inputs by sqrt(d_model) + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: true # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers pos_emb_max_len: 5000 diff --git a/examples/asr/conf/conformer/conformer_transducer_char.yaml b/examples/asr/conf/conformer/conformer_transducer_char.yaml index 3f5d0ec2f..86f734b91 100644 --- a/examples/asr/conf/conformer/conformer_transducer_char.yaml +++ b/examples/asr/conf/conformer/conformer_transducer_char.yaml @@ -31,9 +31,6 @@ model: "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] model_defaults: - se: true - se_context_size: -1 - # encoder / decoder / joint values enc_hidden: ${model.encoder.d_model} pred_hidden: 640 joint_hidden: 640 @@ -102,8 +99,10 @@ model: # Multi-headed Attention Module's params self_attention_model: rel_pos # rel_pos or abs_pos - n_heads: 8 - xscaling: true # scales up the inputs by sqrt(d_model) + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: true # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers pos_emb_max_len: 5000 diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py new file mode 100644 index 000000000..13328119b --- /dev/null +++ b/examples/asr/transcribe_speech_parallel.py @@ -0,0 +1,189 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +# ASR transcribe/inference with multi-GPU/multi-node support for large datasets +# It supports both tarred and non-tarred datasets +# Arguments +# model: path to a nemo/PTL checkpoint file or name of a pretrained model +# predict_ds: config of the dataset/dataloader +# output_path: path to store the predictions +# return_predictions: whether to return the predictions as output other than writing into the files +# use_cer: whether to calculate the error in terms of CER or use the default WER +# +# Results of each GPU/worker is written into a file named 'predictions_{rank}.json, and aggregated results of all workers are written into 'predictions_all.json' + +Example for non-tarred datasets: + +python transcribe_speech_parallel.py \ + model=stt_en_conformer_ctc_large \ + predict_ds.manifest_filepath=/dataset/manifest_file.json \ + predict_ds.batch_size=16 \ + output_path=/tmp/ + +Example for tarred datasets: + +python transcribe_speech_parallel.py \ + predict_ds.is_tarred=true \ + predict_ds.manifest_filepath=/tarred_dataset/tarred_audio_manifest.json \ + predict_ds.tarred_audio_filepaths=/tarred_dataset/audio__OP_0..127_CL_.tar \ + ... + +By default the trainer uses all the GPUs available and default precision is FP32. +By setting the trainer config you may control these configs. For example to do the predictions with AMP on just two GPUs: + +python transcribe_speech_parallel.py \ + trainer.precision=16 \ + trainer.gpus=2 \ + ... + +You may control the dataloader's config by setting the predict_ds: + +python transcribe_speech_parallel.py \ + predict_ds.num_workers=8 \ + predict_ds.min_duration=2.0 \ + predict_ds.sample_rate=16000 \ + model=stt_en_conformer_ctc_small \ + ... + +""" + + +import itertools +import json +import os +from dataclasses import dataclass, is_dataclass +from typing import Optional + +import pytorch_lightning as ptl +import torch +from omegaconf import OmegaConf + +from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter +from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig +from nemo.core.config import TrainerConfig, hydra_runner +from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + + +@dataclass +class ParallelTranscriptionConfig: + model: Optional[str] = None # name + predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4) + output_path: Optional[str] = None + # when return_predictions is enabled, the prediction call would keep all the predictions in memory and return them when prediction is done + return_predictions: bool = False + use_cer: bool = False + + # decoding strategy for RNNT models + rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig() + trainer: TrainerConfig = TrainerConfig(gpus=-1, accelerator="ddp") + + +def match_train_config(predict_ds, train_ds): + # It copies the important configurations from the train dataset of the model + # into the predict_ds to be used for prediction. It is needed to match the training configurations. 
+ if train_ds is None: + return + + predict_ds.sample_rate = train_ds.get("sample_rate", 16000) + cfg_name_list = [ + "int_values", + "use_start_end_token", + "blank_index", + "unk_index", + "normalize", + "parser", + "eos_id", + "bos_id", + "pad_id", + ] + + if is_dataclass(predict_ds): + predict_ds = OmegaConf.structured(predict_ds) + for cfg_name in cfg_name_list: + if hasattr(train_ds, cfg_name): + setattr(predict_ds, cfg_name, getattr(train_ds, cfg_name)) + + return predict_ds + + +@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig) +def main(cfg: ParallelTranscriptionConfig): + if cfg.model.endswith(".nemo"): + logging.info("Attempting to initialize from .nemo file") + model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu") + elif cfg.model.endswith(".ckpt"): + logging.info("Attempting to initialize from .ckpt file") + model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu") + else: + logging.info( + "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt" + ) + model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu") + + trainer = ptl.Trainer(**cfg.trainer) + + cfg.predict_ds.return_sample_id = True + cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds) + data_loader = model._setup_dataloader_from_config(cfg.predict_ds) + + os.makedirs(cfg.output_path, exist_ok=True) + # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank. + global_rank = trainer.node_rank * trainer.num_gpus + int(os.environ.get("LOCAL_RANK", 0)) + output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json") + predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file) + trainer.callbacks.extend([predictor_writer]) + + predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions) + if predictions is not None: + predictions = list(itertools.chain.from_iterable(predictions)) + samples_num = predictor_writer.close_output_file() + + logging.info( + f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}." + ) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + samples_num = 0 + pred_text_list = [] + text_list = [] + if is_global_rank_zero(): + output_file = os.path.join(cfg.output_path, f"predictions_all.json") + logging.info(f"Prediction files are being aggregated in {output_file}.") + with open(output_file, 'w') as outf: + for rank in range(trainer.world_size): + input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json") + with open(input_file, 'r') as inpf: + lines = inpf.readlines() + for line in lines: + item = json.loads(line) + pred_text_list.append(item["pred_text"]) + text_list.append(item["text"]) + outf.write(json.dumps(item) + "\n") + samples_num += 1 + wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer) + logging.info( + f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}." 
+ ) + logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer)) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index a68db2d18..2da1eedc8 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -50,7 +50,14 @@ def _speech_collate_fn(batch, pad_id): encoded tokens, and encoded tokens length. This collate func assumes the signals are 1d torch tensors (i.e. mono audio). """ - _, audio_lengths, _, tokens_lengths = zip(*batch) + packed_batch = list(zip(*batch)) + if len(packed_batch) == 5: + _, audio_lengths, _, tokens_lengths, sample_ids = packed_batch + elif len(packed_batch) == 4: + sample_ids = None + _, audio_lengths, _, tokens_lengths = packed_batch + else: + raise ValueError("Expects 4 or 5 tensors in the batch!") max_audio_len = 0 has_audio = audio_lengths[0] is not None if has_audio: @@ -58,7 +65,11 @@ def _speech_collate_fn(batch, pad_id): max_tokens_len = max(tokens_lengths).item() audio_signal, tokens = [], [] - for sig, sig_len, tokens_i, tokens_i_len in batch: + for b in batch: + if len(b) == 5: + sig, sig_len, tokens_i, tokens_i_len, _ = b + else: + sig, sig_len, tokens_i, tokens_i_len = b if has_audio: sig_len = sig_len.item() if sig_len < max_audio_len: @@ -78,8 +89,11 @@ def _speech_collate_fn(batch, pad_id): audio_signal, audio_lengths = None, None tokens = torch.stack(tokens) tokens_lengths = torch.stack(tokens_lengths) - - return audio_signal, audio_lengths, tokens, tokens_lengths + if sample_ids is None: + return audio_signal, audio_lengths, tokens, tokens_lengths + else: + sample_ids = torch.tensor(sample_ids, dtype=torch.int32) + return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids class ASRManifestProcessor: @@ -115,7 +129,7 @@ class ASRManifestProcessor: self.parser = parser self.collection = collections.ASRAudioText( - manifests_files=manifest_filepath.split(','), + manifests_files=manifest_filepath, parser=parser, min_duration=min_duration, max_duration=max_duration, @@ -164,6 +178,7 @@ class _AudioTextDataset(Dataset): normalize: whether to normalize transcript text (default): True bos_id: Id of beginning of sequence symbol to append if not None eos_id: Id of end of sequence symbol to append if not None + return_sample_id (bool): whether to return the sample_id as a part of each sample """ @property @@ -175,6 +190,7 @@ class _AudioTextDataset(Dataset): 'a_sig_length': NeuralType(tuple('B'), LengthsType()), 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } def __init__( @@ -191,6 +207,7 @@ class _AudioTextDataset(Dataset): bos_id: Optional[int] = None, eos_id: Optional[int] = None, pad_id: int = 0, + return_sample_id: bool = False, ): self.manifest_processor = ASRManifestProcessor( manifest_filepath=manifest_filepath, @@ -204,6 +221,10 @@ class _AudioTextDataset(Dataset): ) self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) self.trim = trim + self.return_sample_id = return_sample_id + + def get_manifest_sample(self, sample_id): + return self.manifest_processor.collection[sample_id] def __getitem__(self, index): sample = self.manifest_processor.collection[index] @@ -219,7 +240,10 @@ class _AudioTextDataset(Dataset): t, tl = self.manifest_processor.process_text(index) - output = f, 
fl, torch.tensor(t).long(), torch.tensor(tl).long() + if self.return_sample_id: + output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index + else: + output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long() return output @@ -258,6 +282,7 @@ class AudioToCharDataset(_AudioTextDataset): normalize: whether to normalize transcript text (default): True bos_id: Id of beginning of sequence symbol to append if not None eos_id: Id of end of sequence symbol to append if not None + return_sample_id (bool): whether to return the sample_id as a part of each sample """ @property @@ -269,6 +294,7 @@ class AudioToCharDataset(_AudioTextDataset): 'a_sig_length': NeuralType(tuple('B'), LengthsType()), 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } def __init__( @@ -289,6 +315,7 @@ class AudioToCharDataset(_AudioTextDataset): eos_id: Optional[int] = None, pad_id: int = 0, parser: Union[str, Callable] = 'en', + return_sample_id: bool = False, ): self.labels = labels @@ -309,6 +336,7 @@ class AudioToCharDataset(_AudioTextDataset): bos_id=bos_id, eos_id=eos_id, pad_id=pad_id, + return_sample_id=return_sample_id, ) @@ -806,6 +834,7 @@ class AudioToBPEDataset(_AudioTextDataset): trim: Whether to trim silence segments use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and ending of speech respectively. + return_sample_id (bool): whether to return the sample_id as a part of each sample """ @property @@ -817,6 +846,7 @@ class AudioToBPEDataset(_AudioTextDataset): 'a_sig_length': NeuralType(tuple('B'), LengthsType()), 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } def __init__( @@ -831,6 +861,7 @@ class AudioToBPEDataset(_AudioTextDataset): max_utts: int = 0, trim: bool = False, use_start_end_token: bool = True, + return_sample_id: bool = False, ): if use_start_end_token and hasattr(tokenizer, 'bos_token'): bos_id = tokenizer.bos_id @@ -868,6 +899,7 @@ class AudioToBPEDataset(_AudioTextDataset): eos_id=eos_id, pad_id=pad_id, trim=trim, + return_sample_id=return_sample_id, ) @@ -956,12 +988,13 @@ class _TarredAudioToTextDataset(IterableDataset): sampled at least once during 1 epoch. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. 
+ return_sample_id (bool): whether to return the sample_id as a part of each sample """ def __init__( self, audio_tar_filepaths: Union[str, List[str]], - manifest_filepath: Union[str, List[str]], + manifest_filepath: str, parser: Callable, sample_rate: int, int_values: bool = False, @@ -977,6 +1010,7 @@ class _TarredAudioToTextDataset(IterableDataset): shard_strategy: str = "scatter", global_rank: int = 0, world_size: int = 0, + return_sample_id: bool = False, ): self.collection = collections.ASRAudioText( manifests_files=manifest_filepath, @@ -992,6 +1026,7 @@ class _TarredAudioToTextDataset(IterableDataset): self.eos_id = eos_id self.bos_id = bos_id self.pad_id = pad_id + self.return_sample_id = return_sample_id valid_shard_strategies = ['scatter', 'replicate'] if shard_strategy not in valid_shard_strategies: @@ -1118,7 +1153,13 @@ class _TarredAudioToTextDataset(IterableDataset): t = t + [self.eos_id] tl += 1 - return f, fl, torch.tensor(t).long(), torch.tensor(tl).long() + if self.return_sample_id: + return f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), manifest_idx + else: + return f, fl, torch.tensor(t).long(), torch.tensor(tl).long() + + def get_manifest_sample(self, sample_id): + return self.collection[sample_id] def __iter__(self): return self._dataset.__iter__() @@ -1208,6 +1249,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset): sampled at least once during 1 epoch. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. + return_sample_id (bool): whether to return the sample_id as a part of each sample """ def __init__( @@ -1233,6 +1275,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset): shard_strategy: str = "scatter", global_rank: int = 0, world_size: int = 0, + return_sample_id: bool = False, ): self.labels = labels @@ -1258,6 +1301,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset): shard_strategy=shard_strategy, global_rank=global_rank, world_size=world_size, + return_sample_id=return_sample_id, ) @@ -1332,12 +1376,13 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset): sampled at least once during 1 epoch. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. 
+ return_sample_id (bool): whether to return the sample_id as a part of each sample """ def __init__( self, audio_tar_filepaths: Union[str, List[str]], - manifest_filepath: Union[str, List[str]], + manifest_filepath: str, tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', sample_rate: int, int_values: bool = False, @@ -1351,6 +1396,7 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset): shard_strategy: str = "scatter", global_rank: int = 0, world_size: int = 0, + return_sample_id: bool = False, ): if use_start_end_token and hasattr(tokenizer, 'bos_token'): bos_id = tokenizer.bos_id @@ -1393,4 +1439,5 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset): shard_strategy=shard_strategy, global_rank=global_rank, world_size=world_size, + return_sample_id=return_sample_id, ) diff --git a/nemo/collections/asr/data/audio_to_text_dali.py b/nemo/collections/asr/data/audio_to_text_dali.py index cd6dc03da..3c46523dc 100644 --- a/nemo/collections/asr/data/audio_to_text_dali.py +++ b/nemo/collections/asr/data/audio_to_text_dali.py @@ -140,6 +140,7 @@ class _AudioTextDALIDataset(Iterator): global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor. + return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet). """ def __init__( @@ -162,8 +163,15 @@ class _AudioTextDALIDataset(Iterator): global_rank: int = 0, world_size: int = 1, preprocessor_cfg: DictConfig = None, + return_sample_id: bool = False, ): self.drop_last = drop_last # used by lr_scheduler + if return_sample_id: + raise ValueError( + "Currently DALI data layers don't support returning the sample_id and return_sample_id can not be enabled." + ) + self.return_sample_id = return_sample_id + if not HAVE_DALI: raise ModuleNotFoundError( f"{self} requires NVIDIA DALI to be installed. " @@ -519,6 +527,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset): global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor. + return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet). """ def __init__( @@ -545,6 +554,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset): global_rank: int = 0, world_size: int = 1, preprocessor_cfg: DictConfig = None, + return_sample_id: bool = False, ): self.labels = labels @@ -571,6 +581,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset): global_rank=global_rank, world_size=world_size, preprocessor_cfg=preprocessor_cfg, + return_sample_id=return_sample_id, ) @@ -607,6 +618,7 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset): preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor. use_start_end_token (bool): Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and ending of speech respectively. + return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet). 
""" def __init__( @@ -627,7 +639,9 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset): world_size: int = 1, preprocessor_cfg: DictConfig = None, use_start_end_token: bool = True, + return_sample_id: bool = False, ): + if use_start_end_token and hasattr(tokenizer, 'bos_token'): bos_id = tokenizer.bos_id else: @@ -670,4 +684,5 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset): global_rank=global_rank, world_size=world_size, preprocessor_cfg=preprocessor_cfg, + return_sample_id=return_sample_id, ) diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index 52de19cc7..dd82d6c16 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +import json +from typing import Any, List, Optional, Union import torch from omegaconf import DictConfig, open_dict from omegaconf.listconfig import ListConfig +from pytorch_lightning.callbacks import BasePredictionWriter from torch.utils.data import ChainDataset from nemo.collections.asr.data import audio_to_text, audio_to_text_dali @@ -91,6 +93,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) normalize=config.get('normalize_transcripts', False), trim=config.get('trim_silence', False), parser=config.get('parser', 'en'), + return_sample_id=config.get('return_sample_id', False), ) return dataset @@ -120,6 +123,7 @@ def get_bpe_dataset( max_utts=config.get('max_utts', 0), trim=config.get('trim_silence', False), use_start_end_token=config.get('use_start_end_token', True), + return_sample_id=config.get('return_sample_id', False), ) return dataset @@ -182,6 +186,7 @@ def get_tarred_dataset( shard_strategy=config.get('tarred_shard_strategy', 'scatter'), global_rank=global_rank, world_size=world_size, + return_sample_id=config.get('return_sample_id', False), ) else: dataset = audio_to_text.TarredAudioToBPEDataset( @@ -200,6 +205,7 @@ def get_tarred_dataset( shard_strategy=config.get('tarred_shard_strategy', 'scatter'), global_rank=global_rank, world_size=world_size, + return_sample_id=config.get('return_sample_id', False), ) datasets.append(dataset) @@ -251,6 +257,7 @@ def get_dali_char_dataset( global_rank=global_rank, world_size=world_size, preprocessor_cfg=preprocessor_cfg, + return_sample_id=config.get('return_sample_id', False), ) return dataset @@ -295,10 +302,44 @@ def get_dali_bpe_dataset( global_rank=global_rank, world_size=world_size, preprocessor_cfg=preprocessor_cfg, + return_sample_id=config.get('return_sample_id', False), ) return dataset +class ASRPredictionWriter(BasePredictionWriter): + def __init__(self, dataset, output_file: str): + super().__init__(write_interval="batch") + self.outf = open(output_file, 'w') + self.dataset = dataset + self.samples_num = 0 + + def write_on_batch_end( + self, + trainer, + pl_module: 'LightningModule', + prediction: Any, + batch_indices: List[int], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ): + for sample_id, transcribed_text in prediction: + item = {} + sample = self.dataset.get_manifest_sample(sample_id) + item["audio_filepath"] = sample.audio_file + item["duration"] = sample.duration + item["text"] = sample.text_raw + item["pred_text"] = transcribed_text + self.outf.write(json.dumps(item) + "\n") + self.samples_num += 1 + return + + def close_output_file(self): + 
self.outf.close() + return self.samples_num + + def convert_to_config_list(initial_list): if initial_list is None or initial_list == []: raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.") diff --git a/nemo/collections/asr/models/configs/__init__.py b/nemo/collections/asr/models/configs/__init__.py index 1dba1be34..072af5a18 100644 --- a/nemo/collections/asr/models/configs/__init__.py +++ b/nemo/collections/asr/models/configs/__init__.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.asr.models.configs.asr_models_config import ( + ASRDatasetConfig, + EncDecCTCConfig, + EncDecCTCModelConfig, +) from nemo.collections.asr.models.configs.classification_models_config import ( EncDecClassificationConfig, EncDecClassificationDatasetConfig, EncDecClassificationModelConfig, ) -from nemo.collections.asr.models.configs.ctc_models_config import ( - EncDecCTCConfig, - EncDecCTCDatasetConfig, - EncDecCTCModelConfig, -) from nemo.collections.asr.models.configs.matchboxnet_config import ( EncDecClassificationModelConfigBuilder, MatchboxNetModelConfig, diff --git a/nemo/collections/asr/models/configs/ctc_models_config.py b/nemo/collections/asr/models/configs/asr_models_config.py similarity index 84% rename from nemo/collections/asr/models/configs/ctc_models_config.py rename to nemo/collections/asr/models/configs/asr_models_config.py index 91d79a4c0..b0ca4fd27 100644 --- a/nemo/collections/asr/models/configs/ctc_models_config.py +++ b/nemo/collections/asr/models/configs/asr_models_config.py @@ -27,15 +27,15 @@ from nemo.core.config import modelPT as model_cfg @dataclass -class EncDecCTCDatasetConfig(nemo.core.classes.dataset.DatasetConfig): - manifest_filepath: Optional[str] = None +class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig): + manifest_filepath: Optional[Any] = None sample_rate: int = MISSING labels: List[str] = MISSING trim_silence: bool = False # Tarred dataset support is_tarred: bool = False - tarred_audio_filepaths: Optional[str] = None + tarred_audio_filepaths: Optional[Any] = None tarred_shard_strategy: str = "scatter" shuffle_n: int = 0 @@ -54,6 +54,7 @@ class EncDecCTCDatasetConfig(nemo.core.classes.dataset.DatasetConfig): bos_id: Optional[int] = None pad_id: int = 0 use_start_end_token: bool = False + return_sample_id: Optional[bool] = False @dataclass @@ -66,9 +67,9 @@ class EncDecCTCConfig(model_cfg.ModelConfig): labels: List[str] = MISSING # Dataset configs - train_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=True) - validation_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False) - test_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False) + train_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=True) + validation_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False) + test_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False) # Optimizer / Scheduler config optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig()) diff --git a/nemo/collections/asr/models/configs/quartznet_config.py b/nemo/collections/asr/models/configs/quartznet_config.py index 1850fce17..a1231002a 100644 --- a/nemo/collections/asr/models/configs/quartznet_config.py +++ b/nemo/collections/asr/models/configs/quartznet_config.py @@ -17,7 +17,7 @@ from typing 
import Any, Callable, List, Optional from omegaconf import MISSING -from nemo.collections.asr.models.configs import ctc_models_config as ctc_cfg +from nemo.collections.asr.models.configs import asr_models_config as ctc_cfg from nemo.collections.asr.modules.audio_preprocessing import ( AudioToMelSpectrogramPreprocessorConfig, SpectrogramAugmentationConfig, @@ -174,13 +174,11 @@ class JasperModelConfig(ctc_cfg.EncDecCTCConfig): labels: List[str] = MISSING # Dataset configs - train_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig( + train_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig( manifest_filepath=None, shuffle=True, trim_silence=True ) - validation_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig( - manifest_filepath=None, shuffle=False - ) - test_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False) + validation_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False) + test_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False) # Optimizer / Scheduler config optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig()) @@ -262,13 +260,13 @@ class EncDecCTCModelConfigBuilder(model_cfg.ModelConfigBuilder): # Note: Autocomplete for users wont work without these overrides # But practically it is not needed since python will infer at runtime - # def set_train_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None): + # def set_train_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None): # super().set_train_ds(cfg) # - # def set_validation_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None): + # def set_validation_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None): # super().set_validation_ds(cfg) # - # def set_test_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None): + # def set_test_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None): # super().set_test_ds(cfg) def _finalize_cfg(self): diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index a2b54a8e7..f8affacd5 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -506,6 +506,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin): "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True), } @property @@ -596,6 +597,22 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin): return {'loss': loss_value, 'log': tensorboard_logs} + def predict_step(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len, sample_id = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, predictions = self.forward( + processed_signal=signal, processed_signal_length=signal_len + ) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + + transcribed_texts = self._wer.ctc_decoder_predictions_tensor( + predictions=predictions, predictions_len=encoded_len, return_hypotheses=False, + ) + + sample_id = sample_id.cpu().detach().numpy() + return list(zip(sample_id, transcribed_texts)) + def 
validation_step(self, batch, batch_idx, dataloader_idx=0): signal, signal_len, transcript, transcript_len = batch if isinstance(batch, DALIOutputs) and batch.has_processed_signal: diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 3d8fa5bf5..04ea98708 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -647,7 +647,6 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel): if hasattr(self, '_trainer') and self._trainer is not None: log_every_n_steps = self._trainer.log_every_n_steps sample_id = self._trainer.global_step - else: log_every_n_steps = 1 sample_id = batch_nb @@ -699,6 +698,23 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel): return {'loss': loss_value} + def predict_step(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len, sample_id = batch + + # forward() only performs encoder forward + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) + + sample_id = sample_id.cpu().detach().numpy() + return list(zip(sample_id, best_hyp_text)) + def validation_step(self, batch, batch_idx, dataloader_idx=0): signal, signal_len, transcript, transcript_len = batch diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 6bd6058f7..5c71bfbd0 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -269,6 +269,7 @@ class Typing(ABC): # Precompute metadata metadata = TypecheckMetadata(original_types=output_types, ignore_collections=ignore_collections) out_types_list = list(metadata.base_types.items()) + mandatory_out_types_list = list(metadata.mandatory_types.items()) # First convert all outputs to list/tuple format to check correct number of outputs if type(out_objects) in (list, tuple): @@ -287,15 +288,15 @@ class Typing(ABC): pass # In all other cases, python will wrap multiple outputs into an outer tuple. - # As such, now the number of elements in this outer tuple should exactly match - # the number of output types defined. - elif len(out_types_list) != len(out_container): + # Allow number of output arguments to be <= total output neural types and >= mandatory outputs. + + elif len(out_container) > len(out_types_list) or len(out_container) < len(mandatory_out_types_list): raise TypeError( - "Number of output arguments provided ({}) is not as expected ({}).\n" - "This can be either because insufficient number of output NeuralTypes were provided," + "Number of output arguments provided ({}) is not as expected. It should be at most {} and at least {}.\n"
It should be larger than {} and less than {}.\n" + "This can be either because insufficient/extra number of output NeuralTypes were provided," "or the provided NeuralTypes {} should enable container support " "(add '[]' to the NeuralType definition)".format( - len(out_container), len(output_types), output_types + len(out_container), len(out_types_list), len(mandatory_out_types_list), output_types ) ) diff --git a/nemo/core/classes/dataset.py b/nemo/core/classes/dataset.py index 51bd46ef7..738ae22f5 100644 --- a/nemo/core/classes/dataset.py +++ b/nemo/core/classes/dataset.py @@ -105,5 +105,5 @@ class DatasetConfig: batch_size: int = 32 drop_last: bool = False shuffle: bool = False - num_workers: Optional[int] = None + num_workers: Optional[int] = 0 pin_memory: bool = True diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 4a3f92b11..fcefb9192 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -237,7 +237,7 @@ class TestEncDecCTCModel: os.rename(old_tokenizer_dir + '.bkp', old_tokenizer_dir) @pytest.mark.unit - def test_EncDecCTCDatasetConfig_for_AudioToBPEDataset(self): + def test_ASRDatasetConfig_for_AudioToBPEDataset(self): # ignore some additional arguments as dataclass is generic IGNORE_ARGS = [ 'is_tarred', @@ -261,10 +261,7 @@ class TestEncDecCTCModel: REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'} result = assert_dataclass_signature_match( - audio_to_text.AudioToBPEDataset, - configs.EncDecCTCDatasetConfig, - ignore_args=IGNORE_ARGS, - remap_args=REMAP_ARGS, + audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, ) signatures_match, cls_subset, dataclass_subset = result @@ -273,7 +270,7 @@ class TestEncDecCTCModel: assert dataclass_subset is None @pytest.mark.unit - def test_EncDecCTCDatasetConfig_for_TarredAudioToBPEDataset(self): + def test_ASRDatasetConfig_for_TarredAudioToBPEDataset(self): # ignore some additional arguments as dataclass is generic IGNORE_ARGS = [ 'is_tarred', @@ -303,7 +300,7 @@ class TestEncDecCTCModel: result = assert_dataclass_signature_match( audio_to_text.TarredAudioToBPEDataset, - configs.EncDecCTCDatasetConfig, + configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, ) diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index be9c502f2..ef1c82342 100644 --- a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -240,7 +240,7 @@ class TestEncDecCTCModel: # trainer and exp manager should be there @pytest.mark.unit - def test_EncDecCTCDatasetConfig_for_AudioToCharDataset(self): + def test_ASRDatasetConfig_for_AudioToCharDataset(self): # ignore some additional arguments as dataclass is generic IGNORE_ARGS = [ 'is_tarred', @@ -259,10 +259,7 @@ class TestEncDecCTCModel: REMAP_ARGS = {'trim_silence': 'trim'} result = assert_dataclass_signature_match( - audio_to_text.AudioToCharDataset, - configs.EncDecCTCDatasetConfig, - ignore_args=IGNORE_ARGS, - remap_args=REMAP_ARGS, + audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, ) signatures_match, cls_subset, dataclass_subset = result @@ -271,7 +268,7 @@ class TestEncDecCTCModel: assert dataclass_subset is None @pytest.mark.unit - def test_EncDecCTCDatasetConfig_for_TarredAudioToCharDataset(self): + 
def test_ASRDatasetConfig_for_TarredAudioToCharDataset(self): # ignore some additional arguments as dataclass is generic IGNORE_ARGS = [ 'is_tarred', @@ -294,7 +291,7 @@ class TestEncDecCTCModel: result = assert_dataclass_signature_match( audio_to_text.TarredAudioToCharDataset, - configs.EncDecCTCDatasetConfig, + configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, ) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index 1e2fe7acb..eaa4abd4c 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -341,7 +341,6 @@ class TestASRDatasets: num_samples = 10 batch_size = 1 - device = 'gpu' if torch.cuda.is_available() else 'cpu' texts = [] with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: @@ -394,7 +393,6 @@ class TestASRDatasets: ) ref_preprocessor = EncDecCTCModel.from_config_dict(preprocessor_cfg) - count = 0 for ref_data, dali_data in zip(ref_dataloader, dali_dataset): ref_audio, ref_audio_len, _, _ = ref_data ref_features, ref_features_len = ref_preprocessor(input_signal=ref_audio, length=ref_audio_len)
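A minimal sketch (not part of the patch) of what the new return_sample_id flag changes at the dataset level: __getitem__ yields a 5-tuple whose last element is the manifest index, and get_manifest_sample() maps that index back to its manifest entry. The manifest path and label set below are hypothetical placeholders.

from nemo.collections.asr.data.audio_to_text import AudioToCharDataset

# Hypothetical manifest and label set; point these at real data to run the sketch.
dataset = AudioToCharDataset(
    manifest_filepath="/data/manifest.json",
    labels=[" ", "a", "b", "c"],
    sample_rate=16000,
    return_sample_id=True,  # new flag introduced by this patch
)

# With return_sample_id=True the sample is a 5-tuple; without it, the usual 4-tuple.
audio, audio_len, tokens, tokens_len, sample_id = dataset[0]

# Map the id back to the manifest entry, as ASRPredictionWriter does when it
# writes predictions_{rank}.json.
entry = dataset.get_manifest_sample(sample_id)
print(entry.audio_file, entry.duration, entry.text_raw)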
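Each line of the per-rank and aggregated prediction files written by ASRPredictionWriter is a JSON record with audio_filepath, duration, text, and pred_text fields, so the results can be re-scored offline. A short sketch, assuming the aggregated file was written to /tmp/predictions_all.json as in the examples above:

import json

from nemo.collections.asr.metrics.wer import word_error_rate

hypotheses, references = [], []
with open("/tmp/predictions_all.json", "r") as f:
    for line in f:
        item = json.loads(line)
        hypotheses.append(item["pred_text"])
        references.append(item["text"])

# Set use_cer=True to report CER instead, mirroring the script's use_cer option.
wer = word_error_rate(hypotheses=hypotheses, references=references, use_cer=False)
print(f"WER over {len(references)} utterances: {wer:.4f}")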
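The relaxation of the output type check in nemo/core/classes/common.py is what allows the datasets above to declare an optional sample_id output type while still returning plain 4-tuples when the flag is off. A toy sketch of the relaxed behavior, using a hypothetical class and dummy tensors (not code from the patch):

import torch

from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, NeuralType


class ToyProducer(Typing):
    @property
    def output_types(self):
        return {
            'transcripts': NeuralType(('B', 'T'), LabelsType()),
            'transcript_length': NeuralType(tuple('B'), LengthsType()),
            'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @typecheck()
    def produce(self, with_ids: bool = False):
        tokens = torch.zeros(2, 4, dtype=torch.long)
        lengths = torch.tensor([4, 4])
        if with_ids:
            return tokens, lengths, torch.tensor([0, 1])
        # Before this change, returning fewer outputs than the declared output types
        # raised a TypeError; now any count between the number of mandatory types and
        # the total number of declared types is accepted.
        return tokens, lengths


producer = ToyProducer()
producer.produce()               # two outputs: allowed, since 'sample_id' is optional
producer.produce(with_ids=True)  # three outputs: still allowed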