NeMo/nemo/collections/asr/models/ctc_models.py
Eric Harper aaacc4b089
Merge r1.5.0 bugfixes and doc updates to main (#3133)
* update branch

Signed-off-by: ericharper <complex451@gmail.com>

* Always save last checkpoint on train end even if folder does not exist (#2976)

* add fix for no checkpoint folder when training ends

Signed-off-by: Jason <jasoli@nvidia.com>

* update

Signed-off-by: Jason <jasoli@nvidia.com>

* fix test

Signed-off-by: Jason <jasoli@nvidia.com>

* fixes

Signed-off-by: Jason <jasoli@nvidia.com>

* typo

Signed-off-by: Jason <jasoli@nvidia.com>

* change check

Signed-off-by: Jason <jasoli@nvidia.com>

* [NLP] Add Apex import guard (#3041)

* add apex import guard

Signed-off-by: ericharper <complex451@gmail.com>

* add apex import guard

Signed-off-by: ericharper <complex451@gmail.com>

* add apex import guard

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* remove from init add logging to constructor

Signed-off-by: ericharper <complex451@gmail.com>

* remove from init add logging to constructor

Signed-off-by: ericharper <complex451@gmail.com>

* remove import from init

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert encoder logic from NLPModel

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert from init

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert from init

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert from init

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert from init

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert from init

Signed-off-by: ericharper <complex451@gmail.com>

* remove megatron bert from init

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* Exp manager small refactor (#3067)

* Exp manager small refactor

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* move super() call earlier in the function

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>

* Change container (#3087)

Signed-off-by: smajumdar <titu1994@gmail.com>

Co-authored-by: Eric Harper <complex451@gmail.com>

* Training of machine translation model fails if config parameter `trainer.max_epochs` is used instead of `trainer.max_steps`. (#3112)

* fix: replace distributed_backend for accelarator

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* Add debug script

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* Remove debug script

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* update (#3113)

Signed-off-by: Jason <jasoli@nvidia.com>

* Fix: punctuation capitalization inference on short queries (#3111)

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

Co-authored-by: Eric Harper <complex451@gmail.com>

* Multiple ASR Fixes to SPE tokenization (#3119)

* Reduce num workers for transcribe

Signed-off-by: smajumdar <titu1994@gmail.com>

* Fix SPE tokenizer vocabulary construction

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update tokenizer building script

Signed-off-by: smajumdar <titu1994@gmail.com>

* Remove logs

Signed-off-by: smajumdar <titu1994@gmail.com>

* Megatron GPT training in BCP (#3095)

* BCP megatron training

Signed-off-by: madhukar <madhukar@penguin>

* Add quotes

Signed-off-by: madhukar <madhukar@penguin>

* Style fix

Signed-off-by: madhukar <madhukar@penguin>

Co-authored-by: madhukar <madhukar@penguin>

* Upgrade to PTL 1.5.0 (#3127)

* update for ptl 1.5.0

Signed-off-by: ericharper <complex451@gmail.com>

* update trainer config

Signed-off-by: ericharper <complex451@gmail.com>

* limit cuda visible devices to the first two gpus on check for ranks CI test

Signed-off-by: ericharper <complex451@gmail.com>

* remove comments

Signed-off-by: ericharper <complex451@gmail.com>

* make datasets larger for test

Signed-off-by: ericharper <complex451@gmail.com>

* make datasets larger for test

Signed-off-by: ericharper <complex451@gmail.com>

* update compute_max_steps

Signed-off-by: ericharper <complex451@gmail.com>

* update compute_max_steps

Signed-off-by: ericharper <complex451@gmail.com>

* update package info

Signed-off-by: ericharper <complex451@gmail.com>

* remove duplicate code

Signed-off-by: ericharper <complex451@gmail.com>

* remove comment

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: Jason <jasoli@nvidia.com>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: PeganovAnton <peganoff2@mail.ru>
Co-authored-by: Madhukar K <26607911+madhukarkm@users.noreply.github.com>
Co-authored-by: madhukar <madhukar@penguin>
2021-11-04 10:26:58 -06:00


# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
import tempfile
from math import ceil
from typing import Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import ChainDataset
from tqdm.auto import tqdm

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
from nemo.utils import logging

__all__ = ['EncDecCTCModel']


class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
    """Base class for encoder decoder CTC-based models."""

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        results = []

        model = PretrainedModelInfo(
            pretrained_model_name="QuartzNet15x5Base-En",
            description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_quartznet15x5/versions/1.0.0rc1/files/stt_en_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_jasper10x5dr",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_jasper10x5dr",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_jasper10x5dr/versions/1.0.0rc1/files/stt_en_jasper10x5dr.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_ca_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ca_quartznet15x5/versions/1.0.0rc1/files/stt_ca_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_it_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_it_quartznet15x5/versions/1.0.0rc1/files/stt_it_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_fr_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_quartznet15x5/versions/1.0.0rc1/files/stt_fr_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_es_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_quartznet15x5/versions/1.0.0rc1/files/stt_es_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_de_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_quartznet15x5/versions/1.0.0rc1/files/stt_de_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_pl_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_pl_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_pl_quartznet15x5/versions/1.0.0rc1/files/stt_pl_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_ru_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ru_quartznet15x5/versions/1.0.0rc1/files/stt_ru_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_zh_citrinet_512",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_512",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_citrinet_512/versions/1.0.0rc1/files/stt_zh_citrinet_512.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_zh_citrinet_1024_gamma_0_25",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_1024_gamma_0_25",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_citrinet_1024_gamma_0_25/versions/1.0.0/files/stt_zh_citrinet_1024_gamma_0_25.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="asr_talknet_aligner",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:asr_talknet_aligner",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/asr_talknet_aligner/versions/1.0.0rc1/files/qn5x5_libri_tts_phonemes.nemo",
        )
        results.append(model)

        return results

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            if "feat_in" not in self._cfg.decoder or (
                not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out')
            ):
                self._cfg.decoder.feat_in = self.encoder._feat_out
            if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                raise ValueError("param feat_in of the decoder's config is not set!")

        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )

        if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(
            vocabulary=self.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )

    @torch.no_grad()
    def transcribe(
        self,
        paths2audio_files: List[str],
        batch_size: int = 4,
        logprobs: bool = False,
        return_hypotheses: bool = False,
    ) -> List[str]:
        """
        Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.

        Args:
            paths2audio_files: (a list) of paths to audio files. \
                Recommended length per file is between 5 and 25 seconds. \
                Files several hours long can also be passed if enough GPU memory is available.
            batch_size: (int) batch size to use during inference.
                Larger batches give better throughput but use more memory.
            logprobs: (bool) pass True to get log probabilities instead of transcripts.
            return_hypotheses: (bool) pass True to return hypothesis objects instead of plain text.
                Hypotheses allow postprocessing such as extracting timestamps or rescoring.

        Returns:
            A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
        """
        if paths2audio_files is None or len(paths2audio_files) == 0:
            return []

        if return_hypotheses and logprobs:
            raise ValueError(
                "Either `return_hypotheses` or `logprobs` can be True at any given time. "
                "Returned hypotheses will contain the logprobs."
            )

        # We will store transcriptions here
        hypotheses = []
        # Model's mode and device
        mode = self.training
        device = next(self.parameters()).device
        dither_value = self.preprocessor.featurizer.dither
        pad_to_value = self.preprocessor.featurizer.pad_to
        # Read before `try` so the `finally` block can always restore it
        logging_level = logging.get_verbosity()

        try:
            self.preprocessor.featurizer.dither = 0.0
            self.preprocessor.featurizer.pad_to = 0
            # Switch model to evaluation mode
            self.eval()
            # Freeze the encoder and decoder modules
            self.encoder.freeze()
            self.decoder.freeze()
            logging.set_verbosity(logging.WARNING)
            # Work in a temporary directory; the manifest file is stored there
            with tempfile.TemporaryDirectory() as tmpdir:
                with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp:
                    for audio_file in paths2audio_files:
                        entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'}
                        fp.write(json.dumps(entry) + '\n')

                config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir}

                temporary_datalayer = self._setup_transcribe_dataloader(config)
                for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
                    logits, logits_len, greedy_predictions = self.forward(
                        input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
                    )
                    if logprobs:
                        # dump log probs per file
                        for idx in range(logits.shape[0]):
                            lg = logits[idx][: logits_len[idx]]
                            hypotheses.append(lg.cpu().numpy())
                    else:
                        current_hypotheses = self._wer.ctc_decoder_predictions_tensor(
                            greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses,
                        )

                        if return_hypotheses:
                            # attach the per-file log probs to each hypothesis
                            for idx in range(logits.shape[0]):
                                current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]]

                        hypotheses += current_hypotheses

                    del greedy_predictions
                    del logits
                    del test_batch
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            self.preprocessor.featurizer.dither = dither_value
            self.preprocessor.featurizer.pad_to = pad_to_value
            if mode is True:
                self.encoder.unfreeze()
                self.decoder.unfreeze()
            logging.set_verbosity(logging_level)
        return hypotheses

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Changes the vocabulary used during the CTC decoding process. Use this method when fine-tuning from a
        pre-trained model. It changes only the decoder and leaves the encoder and pre-processing modules unchanged.
        For example, you would use it to reuse a pretrained encoder when fine-tuning on data in another language,
        or when you need the model to learn capitalization, punctuation and/or special characters.

        If new_vocabulary == self.decoder.vocabulary then nothing will be changed.

        Args:
            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
                this is the target alphabet.

        Returns: None
        """
        if self.decoder.vocabulary == new_vocabulary:
            logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.")
        else:
            if new_vocabulary is None or len(new_vocabulary) == 0:
                raise ValueError(f'New vocabulary must be a non-empty list of chars, but got: {new_vocabulary}')
            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            new_decoder_config['vocabulary'] = new_vocabulary
            new_decoder_config['num_classes'] = len(new_vocabulary)

            del self.decoder
            self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config)
            del self.loss
            self.loss = CTCLoss(
                num_classes=self.decoder.num_classes_with_blank - 1,
                zero_infinity=True,
                reduction=self._cfg.get("ctc_reduction", "mean_batch"),
            )
            self._wer = WER(
                vocabulary=self.decoder.vocabulary,
                batch_dim_index=0,
                use_cer=self._cfg.get('use_cer', False),
                ctc_decode=True,
                dist_sync_on_step=True,
                log_prediction=self._cfg.get("log_prediction", False),
            )

            # Update config
            OmegaConf.set_struct(self._cfg.decoder, False)
            self._cfg.decoder = new_decoder_config
            OmegaConf.set_struct(self._cfg.decoder, True)

            ds_keys = ['train_ds', 'validation_ds', 'test_ds']
            for key in ds_keys:
                if key in self.cfg:
                    with open_dict(self.cfg[key]):
                        self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary)

            logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.")

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')

        shuffle = config['shuffle']
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        if config.get('use_dali', False):
            device_id = self.local_rank if device == 'gpu' else None
            dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=config,
                shuffle=shuffle,
                device_id=device_id,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )
            return dataset

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
                return None

            dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

        if type(dataset) is ChainDataset:
            collate_fn = dataset.datasets[0].collate_fn
        else:
            collate_fn = dataset.collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the training data loader via a Dict-like object.

        Args:
            train_data_config: A config that contains the information regarding construction
                of an ASR Training dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in train_data_config:
            train_data_config['shuffle'] = True

        # preserve config
        self._update_dataset_config(dataset_name='train', config=train_data_config)

        self._train_dl = self._setup_dataloader_from_config(config=train_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
                )

    def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the validation data loader via a Dict-like object.

        Args:
            val_data_config: A config that contains the information regarding construction
                of an ASR validation dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in val_data_config:
            val_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='validation', config=val_data_config)

        self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the test data loader via a Dict-like object.

        Args:
            test_data_config: A config that contains the information regarding construction
                of an ASR test dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in test_data_config:
            test_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='test', config=test_data_config)

        self._test_dl = self._setup_dataloader_from_config(config=test_data_config)

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "outputs": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            "greedy_predictions": NeuralType(('B', 'T'), LabelsType()),
        }

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
    ):
        """
        Forward pass of the model.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape [B, D, T] that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            A tuple of 3 elements -
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax)
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Exactly one of the two input pairs must be provided
        if not (has_input_signal ^ has_processed_signal):
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                "with ``processed_signal`` and ``processed_signal_length`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal, length=input_signal_length,
            )

        if self.spec_augmentation is not None and self.training:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        log_probs = self.decoder(encoder_output=encoded)
        greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)

        return log_probs, encoded_len, greedy_predictions

    # PTL-specific methods
    def training_step(self, batch, batch_nb):
        signal, signal_len, transcript, transcript_len = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len
            )
        else:
            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(
            log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
        )

        tensorboard_logs = {'train_loss': loss_value, 'learning_rate': self._optimizer.param_groups[0]['lr']}

        if hasattr(self, '_trainer') and self._trainer is not None:
            log_every_n_steps = self._trainer.log_every_n_steps
        else:
            log_every_n_steps = 1

        if (batch_nb + 1) % log_every_n_steps == 0:
            self._wer.update(
                predictions=predictions,
                targets=transcript,
                target_lengths=transcript_len,
                predictions_lengths=encoded_len,
            )
            wer, _, _ = self._wer.compute()
            self._wer.reset()
            tensorboard_logs.update({'training_batch_wer': wer})

        return {'loss': loss_value, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        signal, signal_len, transcript, transcript_len = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len
            )
        else:
            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(
            log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
        )
        self._wer.update(
            predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len
        )
        wer, wer_num, wer_denom = self._wer.compute()
        self._wer.reset()
        return {
            'val_loss': loss_value,
            'val_wer_num': wer_num,
            'val_wer_denom': wer_denom,
            'val_wer': wer,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        test_logs = {
            'test_loss': logs['val_loss'],
            'test_wer_num': logs['val_wer_num'],
            'test_wer_denom': logs['val_wer_denom'],
            'test_wer': logs['val_wer'],
        }
        return test_logs

    def test_dataloader(self):
        if self._test_dl is not None:
            return self._test_dl

    def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio files.

        Args:
            config: A python dictionary which contains the following keys:
            paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \
                Recommended length per file is between 5 and 25 seconds.
            batch_size: (int) batch size to use during inference. \
                Larger batches give better throughput but use more memory.
            temp_dir: (str) A temporary directory where the audio manifest is temporarily
                stored.

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """
        batch_size = min(config['batch_size'], len(config['paths2audio_files']))
        dl_config = {
            'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'),
            'sample_rate': self.preprocessor._sample_rate,
            'labels': self.decoder.vocabulary,
            'batch_size': batch_size,
            'trim_silence': False,
            'shuffle': False,
            # os.cpu_count() may return None; fall back to 0 workers in that case
            'num_workers': min(batch_size, (os.cpu_count() or 1) - 1),
            'pin_memory': True,
        }

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer
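`transcribe` does not pass audio paths to the dataloader directly. It first writes a throwaway JSON-lines manifest, one object per audio file, into a temporary directory. The sketch below reproduces only that manifest step with the standard library. The helper names `write_transcribe_manifest` and `read_manifest` are hypothetical. The placeholder `duration` and `text` values are copied from the method, not explained by it.

```python
import json
import os
import tempfile
from typing import List


def write_transcribe_manifest(paths2audio_files: List[str], tmpdir: str) -> str:
    """Write one JSON object per line, mirroring the entries built in `transcribe`."""
    manifest_path = os.path.join(tmpdir, 'manifest.json')
    with open(manifest_path, 'w') as fp:
        for audio_file in paths2audio_files:
            # Placeholder values copied from `transcribe`; the text is never scored there.
            entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'}
            fp.write(json.dumps(entry) + '\n')
    return manifest_path


def read_manifest(manifest_path: str) -> List[dict]:
    """Parse a JSON-lines manifest back into a list of dicts, skipping blank lines."""
    with open(manifest_path) as fp:
        return [json.loads(line) for line in fp if line.strip()]


with tempfile.TemporaryDirectory() as tmpdir:
    path = write_transcribe_manifest(['a.wav', 'b.wav'], tmpdir)
    entries = read_manifest(path)
```

Because the manifest lives inside `TemporaryDirectory`, it is deleted as soon as the `with` block exits. This is why `transcribe` builds and consumes the dataloader entirely inside that block.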
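When `logprobs` is False, `transcribe` passes the per-frame argmax from `forward` (`greedy_predictions`) to `self._wer.ctc_decoder_predictions_tensor`, trimmed to `logits_len`, to get text. That method is not defined in this file, so the sketch below implements the standard CTC greedy rule instead and should be read as an approximation: merge consecutive repeats, then drop blanks. It assumes the blank token is the last class, index `len(vocabulary)`, as the `CTCLoss(num_classes=self.decoder.num_classes_with_blank - 1, ...)` construction suggests. `ctc_greedy_collapse` and `decode_to_text` are hypothetical names.

```python
from typing import List, Optional, Sequence


def ctc_greedy_collapse(predictions: Sequence[int], blank_id: int) -> List[int]:
    """Merge consecutive repeated labels, then remove blank labels."""
    decoded: List[int] = []
    previous = blank_id
    for label in predictions:
        # Emit a label only if it differs from the previous frame and is not blank,
        # so "a a _ a" yields two a's while "a a a" yields one.
        if label != previous and label != blank_id:
            decoded.append(label)
        previous = label
    return decoded


def decode_to_text(predictions: Sequence[int], vocabulary: List[str], length: Optional[int] = None) -> str:
    """Greedy CTC decode of one utterance; `length` drops padded frames, like `logits_len`."""
    blank_id = len(vocabulary)  # assumption: blank is the last class index
    if length is not None:
        predictions = predictions[:length]
    return ''.join(vocabulary[i] for i in ctc_greedy_collapse(predictions, blank_id))


vocabulary = [' ', 'a', 'b', 'c']  # blank_id == 4
frames = [1, 1, 4, 1, 2, 2, 4, 0, 3]
print(decode_to_text(frames, vocabulary))            # aab c
print(decode_to_text(frames, vocabulary, length=6))  # aab
```

The blank between the two runs of `1` is what allows a genuinely doubled letter ("aa") to survive the repeat-merging step.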
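For tarred (iterable) datasets, `setup_training_data` converts a float `trainer.limit_train_batches` into an absolute per-rank batch count so that the progress bar shows the right number of steps. The arithmetic can be checked in isolation. `adjusted_limit_train_batches` is a hypothetical helper. `num_samples` stands in for `len(self._train_dl.dataset)`, which the source comment says counts samples rather than batches.

```python
from math import ceil
from typing import Union


def adjusted_limit_train_batches(
    limit_train_batches: Union[int, float], num_samples: int, world_size: int, batch_size: int
) -> Union[int, float]:
    """Mirror the tarred-dataset branch of `setup_training_data`."""
    if isinstance(limit_train_batches, float):
        # Samples are split across ranks first, then grouped into batches (rounded up).
        batches_per_rank = ceil((num_samples / world_size) / batch_size)
        return int(limit_train_batches * batches_per_rank)
    # An int is assumed to already be a sensible batch count and is returned unchanged.
    return limit_train_batches


# 10,000 samples on 4 GPUs with batch_size 32: ceil(2500 / 32) = 79 batches per rank.
print(adjusted_limit_train_batches(1.0, 10_000, 4, 32))  # 79
print(adjusted_limit_train_batches(0.5, 10_000, 4, 32))  # 39
print(adjusted_limit_train_batches(50, 10_000, 4, 32))   # 50
```

Because `int()` truncates, a fractional limit rounds down: 0.5 × 79 = 39.5 becomes 39.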
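`validation_step` returns `val_wer_num` and `val_wer_denom` alongside `val_wer`. This makes it possible to compute an epoch-level WER as total word errors over total reference words, instead of averaging per-batch ratios. The epoch-end reduction is not defined in this file, so treating it as a ratio of sums is an assumption here. The sketch also substitutes a plain word-level Levenshtein distance for the `WER` metric class. All function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def word_edit_distance(reference: Sequence[str], hypothesis: Sequence[str]) -> int:
    """Levenshtein distance over words (substitutions, insertions, deletions)."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        current = [i] + [0] * len(hypothesis)
        for j, hyp_word in enumerate(hypothesis, start=1):
            cost = 0 if ref_word == hyp_word else 1
            current[j] = min(previous[j] + 1, current[j - 1] + 1, previous[j - 1] + cost)
        previous = current
    return previous[-1]


def batch_wer_terms(refs: List[str], hyps: List[str]) -> Tuple[int, int]:
    """Return (word errors, reference words) for one batch, analogous to wer_num / wer_denom."""
    num = sum(word_edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    denom = sum(len(r.split()) for r in refs)
    return num, denom


def corpus_wer(batches: List[Tuple[int, int]]) -> float:
    """Ratio of summed errors to summed reference words across all batches."""
    return sum(n for n, _ in batches) / sum(d for _, d in batches)


batches = [
    batch_wer_terms(["the cat sat"], ["the cat sat"]),  # (0, 3)
    batch_wer_terms(["hello"], ["yellow"]),             # (1, 1)
]
print(corpus_wer(batches))                            # 0.25
print(sum(n / d for n, d in batches) / len(batches))  # 0.5 (mean of per-batch WERs)
```

The two results differ because the mean of per-batch ratios gives the one-word second batch the same weight as the three-word first batch. Carrying the numerator and denominator separately avoids that skew.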