Adding parallel transcribe for ASR models - suppports multi-gpu/multi-node (#3017)
* added transcribe_speech_parallel.py. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed style. Signed-off-by: Vahid <vnoroozi@nvidia.com> * removed comments. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed bug. Signed-off-by: Vahid <vnoroozi@nvidia.com> * added comments inside the script. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed speed_collate_fn for TTS. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed speed_collate_fn for TTS. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed speed_collate_fn for TTS. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed speed_collate_fn for TTS. Signed-off-by: Vahid <vnoroozi@nvidia.com> * added return_sample_id. Signed-off-by: Vahid <vnoroozi@nvidia.com> * added return_sample_id. Signed-off-by: Vahid <vnoroozi@nvidia.com> * merged dataset configs. Signed-off-by: Vahid <vnoroozi@nvidia.com> * merged dataset configs. Signed-off-by: Vahid <vnoroozi@nvidia.com> * dropped sample_ids from train and validation. Signed-off-by: Vahid <vnoroozi@nvidia.com> * dropped sample_ids from train and validation. Signed-off-by: Vahid <vnoroozi@nvidia.com> * reverted tts patches. Signed-off-by: Vahid <vnoroozi@nvidia.com> * reverted tts patches. Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed the default values in the dataset's config Signed-off-by: Vahid <vnoroozi@nvidia.com> * fixed the default values in the dataset's config Signed-off-by: Vahid <vnoroozi@nvidia.com> * Fixed the bug for optional outputs. Signed-off-by: Vahid <vnoroozi@nvidia.com> * Fixed the bug for optional outputs. Signed-off-by: Vahid <vnoroozi@nvidia.com> * addressed some comments. Signed-off-by: Vahid <vnoroozi@nvidia.com> * disabled dali support for return_sample_id. Signed-off-by: Vahid <vnoroozi@nvidia.com> * disabled dali support for return_sample_id. Signed-off-by: Vahid <vnoroozi@nvidia.com> * converted the config to omegaconf. Signed-off-by: Vahid <vnoroozi@nvidia.com> * converted the config to omegaconf. Signed-off-by: Vahid <vnoroozi@nvidia.com> * converted the config to omegaconf. Signed-off-by: Vahid <vnoroozi@nvidia.com> * moved wer/cer calculation to the end. Signed-off-by: Vahid <vnoroozi@nvidia.com> * moved wer/cer calculation to the end. Signed-off-by: Vahid <vnoroozi@nvidia.com> * moved wer/cer calculation to the end. Signed-off-by: Vahid <vnoroozi@nvidia.com> * moved wer/cer calculation to the end. Signed-off-by: Vahid <vnoroozi@nvidia.com> * moved wer/cer calculation to the end. Signed-off-by: Vahid <vnoroozi@nvidia.com> * calculates global wer instead of per sample. Signed-off-by: Vahid <vnoroozi@nvidia.com> * calculates global wer instead of per sample. Signed-off-by: Vahid <vnoroozi@nvidia.com> * calculates global wer instead of per sample. Signed-off-by: Vahid <vnoroozi@nvidia.com>
This commit is contained in:
parent
b12ac8ae85
commit
cfcf694e30
|
@ -28,9 +28,6 @@ model:
|
|||
log_prediction: true # enables logging sample predictions in the output during training
|
||||
|
||||
model_defaults:
|
||||
se: true
|
||||
se_context_size: -1
|
||||
# encoder / decoder / joint values
|
||||
enc_hidden: ${model.encoder.d_model}
|
||||
pred_hidden: 640
|
||||
joint_hidden: 640
|
||||
|
@ -107,8 +104,10 @@ model:
|
|||
|
||||
# Multi-headed Attention Module's params
|
||||
self_attention_model: rel_pos # rel_pos or abs_pos
|
||||
n_heads: 8
|
||||
xscaling: true # scales up the inputs by sqrt(d_model)
|
||||
n_heads: 8 # may need to be lower for smaller d_models
|
||||
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
|
||||
att_context_size: [-1, -1] # -1 means unlimited context
|
||||
xscaling: true # scales up the input embeddings by sqrt(d_model)
|
||||
untie_biases: true # unties the biases of the TransformerXL layers
|
||||
pos_emb_max_len: 5000
|
||||
|
||||
|
|
|
@ -31,9 +31,6 @@ model:
|
|||
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
|
||||
|
||||
model_defaults:
|
||||
se: true
|
||||
se_context_size: -1
|
||||
# encoder / decoder / joint values
|
||||
enc_hidden: ${model.encoder.d_model}
|
||||
pred_hidden: 640
|
||||
joint_hidden: 640
|
||||
|
@ -102,8 +99,10 @@ model:
|
|||
|
||||
# Multi-headed Attention Module's params
|
||||
self_attention_model: rel_pos # rel_pos or abs_pos
|
||||
n_heads: 8
|
||||
xscaling: true # scales up the inputs by sqrt(d_model)
|
||||
n_heads: 8 # may need to be lower for smaller d_models
|
||||
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
|
||||
att_context_size: [-1, -1] # -1 means unlimited context
|
||||
xscaling: true # scales up the input embeddings by sqrt(d_model)
|
||||
untie_biases: true # unties the biases of the TransformerXL layers
|
||||
pos_emb_max_len: 5000
|
||||
|
||||
|
|
189
examples/asr/transcribe_speech_parallel.py
Normal file
189
examples/asr/transcribe_speech_parallel.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
# ASR transcribe/inference with multi-GPU/multi-node support for large datasets
|
||||
# It supports both tarred and non-tarred datasets
|
||||
# Arguments
|
||||
# model: path to a nemo/PTL checkpoint file or name of a pretrained model
|
||||
# predict_ds: config of the dataset/dataloader
|
||||
# output_path: path to store the predictions
|
||||
# return_predictions: whether to return the predictions as output other than writing into the files
|
||||
# use_cer: whether to calculate the error in terms of CER or use the default WER
|
||||
#
|
||||
# Results of each GPU/worker is written into a file named 'predictions_{rank}.json, and aggregated results of all workers are written into 'predictions_all.json'
|
||||
|
||||
Example for non-tarred datasets:
|
||||
|
||||
python transcribe_speech_parallel.py \
|
||||
model=stt_en_conformer_ctc_large \
|
||||
predict_ds.manifest_filepath=/dataset/manifest_file.json \
|
||||
predict_ds.batch_size=16 \
|
||||
output_path=/tmp/
|
||||
|
||||
Example for tarred datasets:
|
||||
|
||||
python transcribe_speech_parallel.py \
|
||||
predict_ds.is_tarred=true \
|
||||
predict_ds.manifest_filepath=/tarred_dataset/tarred_audio_manifest.json \
|
||||
predict_ds.tarred_audio_filepaths=/tarred_dataset/audio__OP_0..127_CL_.tar \
|
||||
...
|
||||
|
||||
By default the trainer uses all the GPUs available and default precision is FP32.
|
||||
By setting the trainer config you may control these configs. For example to do the predictions with AMP on just two GPUs:
|
||||
|
||||
python transcribe_speech_parallel.py \
|
||||
trainer.precision=16 \
|
||||
trainer.gpus=2 \
|
||||
...
|
||||
|
||||
You may control the dataloader's config by setting the predict_ds:
|
||||
|
||||
python transcribe_speech_parallel.py \
|
||||
predict_ds.num_workers=8 \
|
||||
predict_ds.min_duration=2.0 \
|
||||
predict_ds.sample_rate=16000 \
|
||||
model=stt_en_conformer_ctc_small \
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytorch_lightning as ptl
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
|
||||
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
|
||||
from nemo.collections.asr.metrics.wer import word_error_rate
|
||||
from nemo.collections.asr.models import ASRModel
|
||||
from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
|
||||
from nemo.core.config import TrainerConfig, hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.get_rank import is_global_rank_zero
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelTranscriptionConfig:
|
||||
model: Optional[str] = None # name
|
||||
predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4)
|
||||
output_path: Optional[str] = None
|
||||
# when return_predictions is enabled, the prediction call would keep all the predictions in memory and return them when prediction is done
|
||||
return_predictions: bool = False
|
||||
use_cer: bool = False
|
||||
|
||||
# decoding strategy for RNNT models
|
||||
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()
|
||||
trainer: TrainerConfig = TrainerConfig(gpus=-1, accelerator="ddp")
|
||||
|
||||
|
||||
def match_train_config(predict_ds, train_ds):
|
||||
# It copies the important configurations from the train dataset of the model
|
||||
# into the predict_ds to be used for prediction. It is needed to match the training configurations.
|
||||
if train_ds is None:
|
||||
return
|
||||
|
||||
predict_ds.sample_rate = train_ds.get("sample_rate", 16000)
|
||||
cfg_name_list = [
|
||||
"int_values",
|
||||
"use_start_end_token",
|
||||
"blank_index",
|
||||
"unk_index",
|
||||
"normalize",
|
||||
"parser",
|
||||
"eos_id",
|
||||
"bos_id",
|
||||
"pad_id",
|
||||
]
|
||||
|
||||
if is_dataclass(predict_ds):
|
||||
predict_ds = OmegaConf.structured(predict_ds)
|
||||
for cfg_name in cfg_name_list:
|
||||
if hasattr(train_ds, cfg_name):
|
||||
setattr(predict_ds, cfg_name, getattr(train_ds, cfg_name))
|
||||
|
||||
return predict_ds
|
||||
|
||||
|
||||
@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig)
|
||||
def main(cfg: ParallelTranscriptionConfig):
|
||||
if cfg.model.endswith(".nemo"):
|
||||
logging.info("Attempting to initialize from .nemo file")
|
||||
model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
|
||||
elif cfg.model.endswith(".ckpt"):
|
||||
logging.info("Attempting to initialize from .ckpt file")
|
||||
model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu")
|
||||
else:
|
||||
logging.info(
|
||||
"Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
|
||||
)
|
||||
model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")
|
||||
|
||||
trainer = ptl.Trainer(**cfg.trainer)
|
||||
|
||||
cfg.predict_ds.return_sample_id = True
|
||||
cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)
|
||||
data_loader = model._setup_dataloader_from_config(cfg.predict_ds)
|
||||
|
||||
os.makedirs(cfg.output_path, exist_ok=True)
|
||||
# trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
|
||||
global_rank = trainer.node_rank * trainer.num_gpus + int(os.environ.get("LOCAL_RANK", 0))
|
||||
output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
|
||||
predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file)
|
||||
trainer.callbacks.extend([predictor_writer])
|
||||
|
||||
predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions)
|
||||
if predictions is not None:
|
||||
predictions = list(itertools.chain.from_iterable(predictions))
|
||||
samples_num = predictor_writer.close_output_file()
|
||||
|
||||
logging.info(
|
||||
f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
|
||||
)
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
|
||||
samples_num = 0
|
||||
pred_text_list = []
|
||||
text_list = []
|
||||
if is_global_rank_zero():
|
||||
output_file = os.path.join(cfg.output_path, f"predictions_all.json")
|
||||
logging.info(f"Prediction files are being aggregated in {output_file}.")
|
||||
with open(output_file, 'w') as outf:
|
||||
for rank in range(trainer.world_size):
|
||||
input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
|
||||
with open(input_file, 'r') as inpf:
|
||||
lines = inpf.readlines()
|
||||
for line in lines:
|
||||
item = json.loads(line)
|
||||
pred_text_list.append(item["pred_text"])
|
||||
text_list.append(item["text"])
|
||||
outf.write(json.dumps(item) + "\n")
|
||||
samples_num += 1
|
||||
wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
|
||||
logging.info(
|
||||
f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
|
||||
)
|
||||
logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -50,7 +50,14 @@ def _speech_collate_fn(batch, pad_id):
|
|||
encoded tokens, and encoded tokens length. This collate func
|
||||
assumes the signals are 1d torch tensors (i.e. mono audio).
|
||||
"""
|
||||
_, audio_lengths, _, tokens_lengths = zip(*batch)
|
||||
packed_batch = list(zip(*batch))
|
||||
if len(packed_batch) == 5:
|
||||
_, audio_lengths, _, tokens_lengths, sample_ids = packed_batch
|
||||
elif len(packed_batch) == 4:
|
||||
sample_ids = None
|
||||
_, audio_lengths, _, tokens_lengths = packed_batch
|
||||
else:
|
||||
raise ValueError("Expects 4 or 5 tensors in the batch!")
|
||||
max_audio_len = 0
|
||||
has_audio = audio_lengths[0] is not None
|
||||
if has_audio:
|
||||
|
@ -58,7 +65,11 @@ def _speech_collate_fn(batch, pad_id):
|
|||
max_tokens_len = max(tokens_lengths).item()
|
||||
|
||||
audio_signal, tokens = [], []
|
||||
for sig, sig_len, tokens_i, tokens_i_len in batch:
|
||||
for b in batch:
|
||||
if len(b) == 5:
|
||||
sig, sig_len, tokens_i, tokens_i_len, _ = b
|
||||
else:
|
||||
sig, sig_len, tokens_i, tokens_i_len = b
|
||||
if has_audio:
|
||||
sig_len = sig_len.item()
|
||||
if sig_len < max_audio_len:
|
||||
|
@ -78,8 +89,11 @@ def _speech_collate_fn(batch, pad_id):
|
|||
audio_signal, audio_lengths = None, None
|
||||
tokens = torch.stack(tokens)
|
||||
tokens_lengths = torch.stack(tokens_lengths)
|
||||
|
||||
return audio_signal, audio_lengths, tokens, tokens_lengths
|
||||
if sample_ids is None:
|
||||
return audio_signal, audio_lengths, tokens, tokens_lengths
|
||||
else:
|
||||
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
|
||||
return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids
|
||||
|
||||
|
||||
class ASRManifestProcessor:
|
||||
|
@ -115,7 +129,7 @@ class ASRManifestProcessor:
|
|||
self.parser = parser
|
||||
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(','),
|
||||
manifests_files=manifest_filepath,
|
||||
parser=parser,
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
|
@ -164,6 +178,7 @@ class _AudioTextDataset(Dataset):
|
|||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -175,6 +190,7 @@ class _AudioTextDataset(Dataset):
|
|||
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
||||
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
||||
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
||||
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
@ -191,6 +207,7 @@ class _AudioTextDataset(Dataset):
|
|||
bos_id: Optional[int] = None,
|
||||
eos_id: Optional[int] = None,
|
||||
pad_id: int = 0,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
self.manifest_processor = ASRManifestProcessor(
|
||||
manifest_filepath=manifest_filepath,
|
||||
|
@ -204,6 +221,10 @@ class _AudioTextDataset(Dataset):
|
|||
)
|
||||
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
|
||||
self.trim = trim
|
||||
self.return_sample_id = return_sample_id
|
||||
|
||||
def get_manifest_sample(self, sample_id):
|
||||
return self.manifest_processor.collection[sample_id]
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.manifest_processor.collection[index]
|
||||
|
@ -219,7 +240,10 @@ class _AudioTextDataset(Dataset):
|
|||
|
||||
t, tl = self.manifest_processor.process_text(index)
|
||||
|
||||
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
if self.return_sample_id:
|
||||
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
|
||||
else:
|
||||
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
return output
|
||||
|
||||
|
@ -258,6 +282,7 @@ class AudioToCharDataset(_AudioTextDataset):
|
|||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -269,6 +294,7 @@ class AudioToCharDataset(_AudioTextDataset):
|
|||
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
||||
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
||||
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
||||
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
@ -289,6 +315,7 @@ class AudioToCharDataset(_AudioTextDataset):
|
|||
eos_id: Optional[int] = None,
|
||||
pad_id: int = 0,
|
||||
parser: Union[str, Callable] = 'en',
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
self.labels = labels
|
||||
|
||||
|
@ -309,6 +336,7 @@ class AudioToCharDataset(_AudioTextDataset):
|
|||
bos_id=bos_id,
|
||||
eos_id=eos_id,
|
||||
pad_id=pad_id,
|
||||
return_sample_id=return_sample_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -806,6 +834,7 @@ class AudioToBPEDataset(_AudioTextDataset):
|
|||
trim: Whether to trim silence segments
|
||||
use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
|
||||
tokens to beginning and ending of speech respectively.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -817,6 +846,7 @@ class AudioToBPEDataset(_AudioTextDataset):
|
|||
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
|
||||
'transcripts': NeuralType(('B', 'T'), LabelsType()),
|
||||
'transcript_length': NeuralType(tuple('B'), LengthsType()),
|
||||
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
@ -831,6 +861,7 @@ class AudioToBPEDataset(_AudioTextDataset):
|
|||
max_utts: int = 0,
|
||||
trim: bool = False,
|
||||
use_start_end_token: bool = True,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
|
||||
bos_id = tokenizer.bos_id
|
||||
|
@ -868,6 +899,7 @@ class AudioToBPEDataset(_AudioTextDataset):
|
|||
eos_id=eos_id,
|
||||
pad_id=pad_id,
|
||||
trim=trim,
|
||||
return_sample_id=return_sample_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -956,12 +988,13 @@ class _TarredAudioToTextDataset(IterableDataset):
|
|||
sampled at least once during 1 epoch.
|
||||
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_tar_filepaths: Union[str, List[str]],
|
||||
manifest_filepath: Union[str, List[str]],
|
||||
manifest_filepath: str,
|
||||
parser: Callable,
|
||||
sample_rate: int,
|
||||
int_values: bool = False,
|
||||
|
@ -977,6 +1010,7 @@ class _TarredAudioToTextDataset(IterableDataset):
|
|||
shard_strategy: str = "scatter",
|
||||
global_rank: int = 0,
|
||||
world_size: int = 0,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath,
|
||||
|
@ -992,6 +1026,7 @@ class _TarredAudioToTextDataset(IterableDataset):
|
|||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.pad_id = pad_id
|
||||
self.return_sample_id = return_sample_id
|
||||
|
||||
valid_shard_strategies = ['scatter', 'replicate']
|
||||
if shard_strategy not in valid_shard_strategies:
|
||||
|
@ -1118,7 +1153,13 @@ class _TarredAudioToTextDataset(IterableDataset):
|
|||
t = t + [self.eos_id]
|
||||
tl += 1
|
||||
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
if self.return_sample_id:
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), manifest_idx
|
||||
else:
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
def get_manifest_sample(self, sample_id):
|
||||
return self.collection[sample_id]
|
||||
|
||||
def __iter__(self):
|
||||
return self._dataset.__iter__()
|
||||
|
@ -1208,6 +1249,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
|
|||
sampled at least once during 1 epoch.
|
||||
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -1233,6 +1275,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
|
|||
shard_strategy: str = "scatter",
|
||||
global_rank: int = 0,
|
||||
world_size: int = 0,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
self.labels = labels
|
||||
|
||||
|
@ -1258,6 +1301,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset):
|
|||
shard_strategy=shard_strategy,
|
||||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
return_sample_id=return_sample_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1332,12 +1376,13 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
|
|||
sampled at least once during 1 epoch.
|
||||
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_tar_filepaths: Union[str, List[str]],
|
||||
manifest_filepath: Union[str, List[str]],
|
||||
manifest_filepath: str,
|
||||
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
|
||||
sample_rate: int,
|
||||
int_values: bool = False,
|
||||
|
@ -1351,6 +1396,7 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
|
|||
shard_strategy: str = "scatter",
|
||||
global_rank: int = 0,
|
||||
world_size: int = 0,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
|
||||
bos_id = tokenizer.bos_id
|
||||
|
@ -1393,4 +1439,5 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
|
|||
shard_strategy=shard_strategy,
|
||||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
return_sample_id=return_sample_id,
|
||||
)
|
||||
|
|
|
@ -140,6 +140,7 @@ class _AudioTextDALIDataset(Iterator):
|
|||
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
|
||||
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -162,8 +163,15 @@ class _AudioTextDALIDataset(Iterator):
|
|||
global_rank: int = 0,
|
||||
world_size: int = 1,
|
||||
preprocessor_cfg: DictConfig = None,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
self.drop_last = drop_last # used by lr_scheduler
|
||||
if return_sample_id:
|
||||
raise ValueError(
|
||||
"Currently DALI data layers don't support returning the sample_id and return_sample_id can not be enabled."
|
||||
)
|
||||
self.return_sample_id = return_sample_id
|
||||
|
||||
if not HAVE_DALI:
|
||||
raise ModuleNotFoundError(
|
||||
f"{self} requires NVIDIA DALI to be installed. "
|
||||
|
@ -519,6 +527,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset):
|
|||
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning shards. Defaults to 1.
|
||||
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -545,6 +554,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset):
|
|||
global_rank: int = 0,
|
||||
world_size: int = 1,
|
||||
preprocessor_cfg: DictConfig = None,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
self.labels = labels
|
||||
|
||||
|
@ -571,6 +581,7 @@ class AudioToCharDALIDataset(_AudioTextDALIDataset):
|
|||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
preprocessor_cfg=preprocessor_cfg,
|
||||
return_sample_id=return_sample_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -607,6 +618,7 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
|
|||
preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.
|
||||
use_start_end_token (bool): Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and
|
||||
ending of speech respectively.
|
||||
return_sample_id (bool): whether to return the sample_id as a part of each sample (not supported yet).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -627,7 +639,9 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
|
|||
world_size: int = 1,
|
||||
preprocessor_cfg: DictConfig = None,
|
||||
use_start_end_token: bool = True,
|
||||
return_sample_id: bool = False,
|
||||
):
|
||||
|
||||
if use_start_end_token and hasattr(tokenizer, 'bos_token'):
|
||||
bos_id = tokenizer.bos_id
|
||||
else:
|
||||
|
@ -670,4 +684,5 @@ class AudioToBPEDALIDataset(_AudioTextDALIDataset):
|
|||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
preprocessor_cfg=preprocessor_cfg,
|
||||
return_sample_id=return_sample_id,
|
||||
)
|
||||
|
|
|
@ -12,11 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
import json
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from omegaconf.listconfig import ListConfig
|
||||
from pytorch_lightning.callbacks import BasePredictionWriter
|
||||
from torch.utils.data import ChainDataset
|
||||
|
||||
from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
|
||||
|
@ -91,6 +93,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None)
|
|||
normalize=config.get('normalize_transcripts', False),
|
||||
trim=config.get('trim_silence', False),
|
||||
parser=config.get('parser', 'en'),
|
||||
return_sample_id=config.get('return_sample_id', False),
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
@ -120,6 +123,7 @@ def get_bpe_dataset(
|
|||
max_utts=config.get('max_utts', 0),
|
||||
trim=config.get('trim_silence', False),
|
||||
use_start_end_token=config.get('use_start_end_token', True),
|
||||
return_sample_id=config.get('return_sample_id', False),
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
@ -182,6 +186,7 @@ def get_tarred_dataset(
|
|||
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
|
||||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
return_sample_id=config.get('return_sample_id', False),
|
||||
)
|
||||
else:
|
||||
dataset = audio_to_text.TarredAudioToBPEDataset(
|
||||
|
@ -200,6 +205,7 @@ def get_tarred_dataset(
|
|||
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
|
||||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
return_sample_id=config.get('return_sample_id', False),
|
||||
)
|
||||
|
||||
datasets.append(dataset)
|
||||
|
@ -251,6 +257,7 @@ def get_dali_char_dataset(
|
|||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
preprocessor_cfg=preprocessor_cfg,
|
||||
return_sample_id=config.get('return_sample_id', False),
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
@ -295,10 +302,44 @@ def get_dali_bpe_dataset(
|
|||
global_rank=global_rank,
|
||||
world_size=world_size,
|
||||
preprocessor_cfg=preprocessor_cfg,
|
||||
return_sample_id=config.get('return_sample_id', False),
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
class ASRPredictionWriter(BasePredictionWriter):
|
||||
def __init__(self, dataset, output_file: str):
|
||||
super().__init__(write_interval="batch")
|
||||
self.outf = open(output_file, 'w')
|
||||
self.dataset = dataset
|
||||
self.samples_num = 0
|
||||
|
||||
def write_on_batch_end(
|
||||
self,
|
||||
trainer,
|
||||
pl_module: 'LightningModule',
|
||||
prediction: Any,
|
||||
batch_indices: List[int],
|
||||
batch: Any,
|
||||
batch_idx: int,
|
||||
dataloader_idx: int,
|
||||
):
|
||||
for sample_id, transcribed_text in prediction:
|
||||
item = {}
|
||||
sample = self.dataset.get_manifest_sample(sample_id)
|
||||
item["audio_filepath"] = sample.audio_file
|
||||
item["duration"] = sample.duration
|
||||
item["text"] = sample.text_raw
|
||||
item["pred_text"] = transcribed_text
|
||||
self.outf.write(json.dumps(item) + "\n")
|
||||
self.samples_num += 1
|
||||
return
|
||||
|
||||
def close_output_file(self):
|
||||
self.outf.close()
|
||||
return self.samples_num
|
||||
|
||||
|
||||
def convert_to_config_list(initial_list):
|
||||
if initial_list is None or initial_list == []:
|
||||
raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")
|
||||
|
|
|
@ -12,16 +12,16 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.asr.models.configs.asr_models_config import (
|
||||
ASRDatasetConfig,
|
||||
EncDecCTCConfig,
|
||||
EncDecCTCModelConfig,
|
||||
)
|
||||
from nemo.collections.asr.models.configs.classification_models_config import (
|
||||
EncDecClassificationConfig,
|
||||
EncDecClassificationDatasetConfig,
|
||||
EncDecClassificationModelConfig,
|
||||
)
|
||||
from nemo.collections.asr.models.configs.ctc_models_config import (
|
||||
EncDecCTCConfig,
|
||||
EncDecCTCDatasetConfig,
|
||||
EncDecCTCModelConfig,
|
||||
)
|
||||
from nemo.collections.asr.models.configs.matchboxnet_config import (
|
||||
EncDecClassificationModelConfigBuilder,
|
||||
MatchboxNetModelConfig,
|
||||
|
|
|
@ -27,15 +27,15 @@ from nemo.core.config import modelPT as model_cfg
|
|||
|
||||
|
||||
@dataclass
|
||||
class EncDecCTCDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
|
||||
manifest_filepath: Optional[str] = None
|
||||
class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
|
||||
manifest_filepath: Optional[Any] = None
|
||||
sample_rate: int = MISSING
|
||||
labels: List[str] = MISSING
|
||||
trim_silence: bool = False
|
||||
|
||||
# Tarred dataset support
|
||||
is_tarred: bool = False
|
||||
tarred_audio_filepaths: Optional[str] = None
|
||||
tarred_audio_filepaths: Optional[Any] = None
|
||||
tarred_shard_strategy: str = "scatter"
|
||||
shuffle_n: int = 0
|
||||
|
||||
|
@ -54,6 +54,7 @@ class EncDecCTCDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
|
|||
bos_id: Optional[int] = None
|
||||
pad_id: int = 0
|
||||
use_start_end_token: bool = False
|
||||
return_sample_id: Optional[bool] = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -66,9 +67,9 @@ class EncDecCTCConfig(model_cfg.ModelConfig):
|
|||
labels: List[str] = MISSING
|
||||
|
||||
# Dataset configs
|
||||
train_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=True)
|
||||
validation_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
test_ds: EncDecCTCDatasetConfig = EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
train_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=True)
|
||||
validation_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
test_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
|
||||
# Optimizer / Scheduler config
|
||||
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
|
|
@ -17,7 +17,7 @@ from typing import Any, Callable, List, Optional
|
|||
|
||||
from omegaconf import MISSING
|
||||
|
||||
from nemo.collections.asr.models.configs import ctc_models_config as ctc_cfg
|
||||
from nemo.collections.asr.models.configs import asr_models_config as ctc_cfg
|
||||
from nemo.collections.asr.modules.audio_preprocessing import (
|
||||
AudioToMelSpectrogramPreprocessorConfig,
|
||||
SpectrogramAugmentationConfig,
|
||||
|
@ -174,13 +174,11 @@ class JasperModelConfig(ctc_cfg.EncDecCTCConfig):
|
|||
labels: List[str] = MISSING
|
||||
|
||||
# Dataset configs
|
||||
train_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(
|
||||
train_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
|
||||
manifest_filepath=None, shuffle=True, trim_silence=True
|
||||
)
|
||||
validation_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(
|
||||
manifest_filepath=None, shuffle=False
|
||||
)
|
||||
test_ds: ctc_cfg.EncDecCTCDatasetConfig = ctc_cfg.EncDecCTCDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
validation_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
test_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False)
|
||||
|
||||
# Optimizer / Scheduler config
|
||||
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
|
||||
|
@ -262,13 +260,13 @@ class EncDecCTCModelConfigBuilder(model_cfg.ModelConfigBuilder):
|
|||
# Note: Autocomplete for users wont work without these overrides
|
||||
# But practically it is not needed since python will infer at runtime
|
||||
|
||||
# def set_train_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None):
|
||||
# def set_train_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None):
|
||||
# super().set_train_ds(cfg)
|
||||
#
|
||||
# def set_validation_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None):
|
||||
# def set_validation_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None):
|
||||
# super().set_validation_ds(cfg)
|
||||
#
|
||||
# def set_test_ds(self, cfg: Optional[ctc_cfg.EncDecCTCDatasetConfig] = None):
|
||||
# def set_test_ds(self, cfg: Optional[ctc_cfg.ASRDatasetConfig] = None):
|
||||
# super().set_test_ds(cfg)
|
||||
|
||||
def _finalize_cfg(self):
|
||||
|
|
|
@ -506,6 +506,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
|
|||
"input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
|
||||
"processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
|
||||
"processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
|
||||
"sample_id": NeuralType(tuple('B'), LengthsType(), optional=True),
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -596,6 +597,22 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
|
|||
|
||||
return {'loss': loss_value, 'log': tensorboard_logs}
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
signal, signal_len, transcript, transcript_len, sample_id = batch
|
||||
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
|
||||
log_probs, encoded_len, predictions = self.forward(
|
||||
processed_signal=signal, processed_signal_length=signal_len
|
||||
)
|
||||
else:
|
||||
log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)
|
||||
|
||||
transcribed_texts = self._wer.ctc_decoder_predictions_tensor(
|
||||
predictions=predictions, predictions_len=encoded_len, return_hypotheses=False,
|
||||
)
|
||||
|
||||
sample_id = sample_id.cpu().detach().numpy()
|
||||
return list(zip(sample_id, transcribed_texts))
|
||||
|
||||
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
signal, signal_len, transcript, transcript_len = batch
|
||||
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
|
||||
|
|
|
@ -647,7 +647,6 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
|
|||
if hasattr(self, '_trainer') and self._trainer is not None:
|
||||
log_every_n_steps = self._trainer.log_every_n_steps
|
||||
sample_id = self._trainer.global_step
|
||||
|
||||
else:
|
||||
log_every_n_steps = 1
|
||||
sample_id = batch_nb
|
||||
|
@ -699,6 +698,23 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
|
|||
|
||||
return {'loss': loss_value}
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
signal, signal_len, transcript, transcript_len, sample_id = batch
|
||||
|
||||
# forward() only performs encoder forward
|
||||
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
|
||||
encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len)
|
||||
else:
|
||||
encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
|
||||
del signal
|
||||
|
||||
best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
|
||||
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
|
||||
)
|
||||
|
||||
sample_id = sample_id.cpu().detach().numpy()
|
||||
return list(zip(sample_id, best_hyp_text))
|
||||
|
||||
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
signal, signal_len, transcript, transcript_len = batch
|
||||
|
||||
|
|
|
@ -269,6 +269,7 @@ class Typing(ABC):
|
|||
# Precompute metadata
|
||||
metadata = TypecheckMetadata(original_types=output_types, ignore_collections=ignore_collections)
|
||||
out_types_list = list(metadata.base_types.items())
|
||||
mandatory_out_types_list = list(metadata.mandatory_types.items())
|
||||
|
||||
# First convert all outputs to list/tuple format to check correct number of outputs
|
||||
if type(out_objects) in (list, tuple):
|
||||
|
@ -287,15 +288,15 @@ class Typing(ABC):
|
|||
pass
|
||||
|
||||
# In all other cases, python will wrap multiple outputs into an outer tuple.
|
||||
# As such, now the number of elements in this outer tuple should exactly match
|
||||
# the number of output types defined.
|
||||
elif len(out_types_list) != len(out_container):
|
||||
# Allow number of output arguments to be <= total output neural types and >= mandatory outputs.
|
||||
|
||||
elif len(out_container) > len(out_types_list) or len(out_container) < len(mandatory_out_types_list):
|
||||
raise TypeError(
|
||||
"Number of output arguments provided ({}) is not as expected ({}).\n"
|
||||
"This can be either because insufficient number of output NeuralTypes were provided,"
|
||||
"Number of output arguments provided ({}) is not as expected. It should be larger than {} and less than {}.\n"
|
||||
"This can be either because insufficient/extra number of output NeuralTypes were provided,"
|
||||
"or the provided NeuralTypes {} should enable container support "
|
||||
"(add '[]' to the NeuralType definition)".format(
|
||||
len(out_container), len(output_types), output_types
|
||||
len(out_container), len(out_types_list), len(mandatory_out_types_list), output_types
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -105,5 +105,5 @@ class DatasetConfig:
|
|||
batch_size: int = 32
|
||||
drop_last: bool = False
|
||||
shuffle: bool = False
|
||||
num_workers: Optional[int] = None
|
||||
num_workers: Optional[int] = 0
|
||||
pin_memory: bool = True
|
||||
|
|
|
@ -237,7 +237,7 @@ class TestEncDecCTCModel:
|
|||
os.rename(old_tokenizer_dir + '.bkp', old_tokenizer_dir)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_EncDecCTCDatasetConfig_for_AudioToBPEDataset(self):
|
||||
def test_ASRDatasetConfig_for_AudioToBPEDataset(self):
|
||||
# ignore some additional arguments as dataclass is generic
|
||||
IGNORE_ARGS = [
|
||||
'is_tarred',
|
||||
|
@ -261,10 +261,7 @@ class TestEncDecCTCModel:
|
|||
REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'}
|
||||
|
||||
result = assert_dataclass_signature_match(
|
||||
audio_to_text.AudioToBPEDataset,
|
||||
configs.EncDecCTCDatasetConfig,
|
||||
ignore_args=IGNORE_ARGS,
|
||||
remap_args=REMAP_ARGS,
|
||||
audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
|
||||
)
|
||||
signatures_match, cls_subset, dataclass_subset = result
|
||||
|
||||
|
@ -273,7 +270,7 @@ class TestEncDecCTCModel:
|
|||
assert dataclass_subset is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_EncDecCTCDatasetConfig_for_TarredAudioToBPEDataset(self):
|
||||
def test_ASRDatasetConfig_for_TarredAudioToBPEDataset(self):
|
||||
# ignore some additional arguments as dataclass is generic
|
||||
IGNORE_ARGS = [
|
||||
'is_tarred',
|
||||
|
@ -303,7 +300,7 @@ class TestEncDecCTCModel:
|
|||
|
||||
result = assert_dataclass_signature_match(
|
||||
audio_to_text.TarredAudioToBPEDataset,
|
||||
configs.EncDecCTCDatasetConfig,
|
||||
configs.ASRDatasetConfig,
|
||||
ignore_args=IGNORE_ARGS,
|
||||
remap_args=REMAP_ARGS,
|
||||
)
|
||||
|
|
|
@ -240,7 +240,7 @@ class TestEncDecCTCModel:
|
|||
# trainer and exp manager should be there
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_EncDecCTCDatasetConfig_for_AudioToCharDataset(self):
|
||||
def test_ASRDatasetConfig_for_AudioToCharDataset(self):
|
||||
# ignore some additional arguments as dataclass is generic
|
||||
IGNORE_ARGS = [
|
||||
'is_tarred',
|
||||
|
@ -259,10 +259,7 @@ class TestEncDecCTCModel:
|
|||
REMAP_ARGS = {'trim_silence': 'trim'}
|
||||
|
||||
result = assert_dataclass_signature_match(
|
||||
audio_to_text.AudioToCharDataset,
|
||||
configs.EncDecCTCDatasetConfig,
|
||||
ignore_args=IGNORE_ARGS,
|
||||
remap_args=REMAP_ARGS,
|
||||
audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
|
||||
)
|
||||
signatures_match, cls_subset, dataclass_subset = result
|
||||
|
||||
|
@ -271,7 +268,7 @@ class TestEncDecCTCModel:
|
|||
assert dataclass_subset is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_EncDecCTCDatasetConfig_for_TarredAudioToCharDataset(self):
|
||||
def test_ASRDatasetConfig_for_TarredAudioToCharDataset(self):
|
||||
# ignore some additional arguments as dataclass is generic
|
||||
IGNORE_ARGS = [
|
||||
'is_tarred',
|
||||
|
@ -294,7 +291,7 @@ class TestEncDecCTCModel:
|
|||
|
||||
result = assert_dataclass_signature_match(
|
||||
audio_to_text.TarredAudioToCharDataset,
|
||||
configs.EncDecCTCDatasetConfig,
|
||||
configs.ASRDatasetConfig,
|
||||
ignore_args=IGNORE_ARGS,
|
||||
remap_args=REMAP_ARGS,
|
||||
)
|
||||
|
|
|
@ -341,7 +341,6 @@ class TestASRDatasets:
|
|||
|
||||
num_samples = 10
|
||||
batch_size = 1
|
||||
device = 'gpu' if torch.cuda.is_available() else 'cpu'
|
||||
texts = []
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
|
||||
|
@ -394,7 +393,6 @@ class TestASRDatasets:
|
|||
)
|
||||
ref_preprocessor = EncDecCTCModel.from_config_dict(preprocessor_cfg)
|
||||
|
||||
count = 0
|
||||
for ref_data, dali_data in zip(ref_dataloader, dali_dataset):
|
||||
ref_audio, ref_audio_len, _, _ = ref_data
|
||||
ref_features, ref_features_len = ref_preprocessor(input_signal=ref_audio, length=ref_audio_len)
|
||||
|
|
Loading…
Reference in a new issue