Add Transducer documentation (#3015)

* Add RNNT documentation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Revert unnecessary changes

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update docs for RNNT

Signed-off-by: smajumdar <titu1994@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
Somshubra Majumdar 2021-10-21 10:40:11 -07:00 committed by GitHub
parent 4e544676f2
commit 9f99918974
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 283 additions and 8 deletions

View file

@ -27,7 +27,7 @@ NeMo supports both character-based and BPE-based models for N-gram LMs. An N-gra
decoders on top of the ASR models to produce more accurate candidates. The beam search decoder would incorporate
the scores produced by the N-gram LM into its score calculations as the following:
.. code::
.. code-block::
final_score = acoustic_score + beam_alpha*lm_score + beam_beta*seq_length
@ -51,7 +51,7 @@ detected automatically from the type of the model.
You may train the N-gram model as the following:
.. code::
.. code-block::
python train_kenlm.py --nemo_model_file <path to the .nemo file of the model> \
--train_file <path to the training text or JSON manifest file \
@ -63,7 +63,7 @@ The train file specified by `--train_file` can be a text file or JSON manifest.
other than `.json`, it assumes that data format is plain text. For plain text format, each line should contain one
sample. For JSON manifest file, the file need to contain json formatted samples per each line like this:
.. code::
.. code-block::
{"audio_filepath": "/data_path/file1.wav", "text": "The transcript of the audio file."}
@ -99,7 +99,7 @@ The script to evaluate an ASR model with beam search decoding and N-gram models
You may evaluate an ASR model as the following:
.. code::
.. code-block::
python eval_beamsearch_ngram.py --nemo_model_file <path to the .nemo file of the model> \
--input_manifest <path to the evaluation JSON manifest file \
@ -180,7 +180,7 @@ You may specify a single or list of values for each of these parameters to perfo
beam search decoding on all the combinations of the these three hyperparameters.
For instance, the following set of parameters would results in 2*1*2=4 beam search decodings:
.. code::
.. code-block::
python eval_beamsearch_ngram.py ... \
--beam_width 64 128 \
@ -223,7 +223,7 @@ search decoding or the result of fusion with an N-gram LM. You may generate this
The neural rescorer would rescore the beams/candidates by using two parameters of `rescorer_alpha` and `rescorer_beta` as the following:
.. code::
.. code-block::
final_score = beam_search_score + rescorer_alpha*neural_rescorer_score + rescorer_beta*seq_length
@ -240,7 +240,8 @@ You may follow the following steps to evaluate a neural LM:
#. Rescore the candidates by `scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py <https://github.com/NVIDIA/NeMo/blob/stable/scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py>`__.
.. code::
.. code-block::
python eval_neural_rescorer.py
--lm_model=[path to .nemo file of the LM or the name of a HF pretrained model]
--beams_file=[path to beams .tsv file]

View file

@ -484,9 +484,237 @@ Conformer-Transducer
Please refer to the model page of `Conformer-Transducer <./models.html#Conformer-Transducer>`__ for more information on this model.
Fine-tuning Configurations
Transducer Configurations
-------------------------
All CTC-based ASR model configs can be modified to support Transducer loss training. Below, we discuss the modifications required in the config to enable Transducer training. All modifications are made to the ``model`` config.
Model Defaults
~~~~~~~~~~~~~~~~~~~~
It is a subsection to the model config representing the default values shared across the entire model represented as ``model.model_defaults``.
There are three values that are primary components of a transducer model. They are :
* ``enc_hidden``: The hidden dimension of the final layer of the Encoder network.
* ``pred_hidden``: The hidden dimension of the final layer of the Prediction network.
* ``joint_hidden``: The hidden dimension of the intermediate layer of the Joint network.
One can access these values inside the config by using OmegaConf interpolation as follows :
.. code-block:: yaml
model:
...
model_defaults:
enc_hidden: 256
pred_hidden: 256
joint_hidden: 256
...
decoder:
...
prednet:
pred_hidden: ${model.model_defaults.pred_hidden}
Acoustic Encoder Model
~~~~~~~~~~~~~~~~~~~~~~
The transducer model is comprised of three models combined. One of these models is the Acoustic (encoder) model. We should be able to drop in any CTC Acoustic model config into this section of the transducer config.
The only condition that needs to be met is that **the final layer of the acoustic model must have the hidden dimension defined in ``model_defaults.enc_hidden``**.
Decoder / Prediction Model
~~~~~~~~~~~~~~~~~~~~~~~~~~
The Prediction model is generally an autoregressive, causal model that consumes text tokens and returns embeddings that will be used by the Joint model. The base config for an LSTM based Prediction network can be found in the the ``decoder`` section of `ContextNet <./models.html#ContextNet>`__ or other Transducer architectures. For further information refer to the ``Intro to Transducers`` tutorial in the ASR tutorial section.
**This config can be copy-pasted into any custom transducer model with no modification.**
Let us discuss some of the important arguments:
* ``blank_as_pad``: In ordinary transducer models, the embedding matrix does not acknowledge the ``Transducer Blank`` token (similar to CTC Blank). However, this causes the autoregressive loop to be more complicated and less efficient. Instead, this flag which is set by default, will add the ``Transducer Blank`` token to the embedding matrix - and use it as a pad value (zeros tensor). This enables more efficient inference without harming training. For further information refer to the ``Intro to Transducers`` tutorial in the ASR tutorial section.
* ``prednet.pred_hidden``: The hidden dimension of the LSTM and the output dimension of the Prediction network.
.. code-block:: yaml
decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null
random_state_sampling: false
blank_as_pad: true
prednet:
pred_hidden: ${model.model_defaults.pred_hidden}
pred_rnn_layers: 1
t_max: null
dropout: 0.0
Joint Model
~~~~~~~~~~~
The Joint model is a simple feed-forward Multi-Layer Perceptron network. This MLP accepts the output of the Acoustic and Prediction models and computes a joint probability distribution over the entire vocabulary space. The base config for the Joint network can be found in the the ``joint`` section of `ContextNet <./models.html#ContextNet>`__ or other Transducer architectures. For further information refer to the ``Intro to Transducers`` tutorial in the ASR tutorial section.
**This config can be copy-pasted into any custom transducer model with no modification.**
The Joint model config has several essential components which we discuss below :
* ``log_softmax``: Due to the cost of computing softmax on such large tensors, the Numba CUDA implementation of RNNT loss will implicitly compute the log softmax when called (so its inputs should be logits). The CPU version of the loss doesn't face such memory issues so it requires log-probabilities instead. Since the behaviour is different for CPU-GPU, the ``None`` value will automatically switch behaviour dependent on whether the input tensor is on a CPU or GPU device.
* ``preserve_memory``: This flag will call ``torch.cuda.empty_cache()`` at certain critical sections when computing the Joint tensor. While this operation might allow us to preserve some memory, the empty_cache() operation is tremendously slow and will slow down training by an order of magnitude or more. It is available to use but not recommended.
* ``fuse_loss_wer``: This flag performs "batch splitting" and then "fused loss + metric" calculation. It will be discussed in detail in the next tutorial that will train a Transducer model.
* ``fused_batch_size``: When the above flag is set to True, the model will have two distinct "batch sizes". The batch size provided in the three data loader configs (``model.*_ds.batch_size``) will now be the ``Acoustic model`` batch size, whereas the ``fused_batch_size`` will be the batch size of the ``Prediction model``, the ``Joint model``, the ``transducer loss`` module and the ``decoding`` module.
* ``jointnet.joint_hidden``: The hidden intermediate dimension of the joint network.
.. code-block:: yaml
joint:
_target_: nemo.collections.asr.modules.RNNTJoint
log_softmax: null # sets it according to cpu/gpu device
# fused mode
fuse_loss_wer: false
fused_batch_size: 16
jointnet:
joint_hidden: ${model.model_defaults.joint_hidden}
activation: "relu"
dropout: 0.0
Effect of Batch Splitting / Fused Batch step
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The following information below explain why memory is an issue when training Transducer models and how NeMo tackles the issue with its Fused Batch step. The material can be read for a thorough understanding, otherwise, it can be skipped. You can also follow these steps in the "ASR_with_Transducers" tutorial.
**Diving deeper into the memory costs of Transducer Joint**
One of the significant limitations of Transducers is the exorbitant memory cost of computing the Joint module. The Joint module is comprised of two steps.
1) Projecting the Acoustic and Transcription feature dimensions to some standard hidden dimension (specified by model.model_defaults.joint_hidden)
2) Projecting this intermediate hidden dimension to the final vocabulary space to obtain the transcription.
Take the following example.
BS=32 ; T (after 2x stride) = 800, U (with character encoding) = 400-450 tokens, Vocabulary size V = 28 (26 alphabet chars, space and apostrophe). Let the hidden dimension of the Joint model be 640 (Most Google Transducer papers use hidden dimension of 640).
* :math:`Memory \, (Hidden, \, gb) = 32 \times 800 \times 450 \times 640 \times 4 = 29.49` gigabytes (4 bytes per float).
* :math:`Memory \, (Joint, \, gb) = 32 \times 800 \times 450 \times 28 \times 4 = 1.290` gigabytes (4 bytes per float)
**NOTE**: This is just for the forward pass! We need to double this memory to store gradients! This much memory is also just for the Joint model **alone**. Far more memory is required for the Prediction model as well as the large Acoustic model itself and its gradients!
Even with mixed precision, that's $\sim 30$ GB of GPU RAM for just 1 part of the network + its gradients.
Effect of Fused Batch Step
^^^^^^^^^^^^^^^^^^^^^^^^^^
The fundamental problem is that the joint tensor grows in size when ``[T x U]`` grows in size. This growth in memory cost is due to many reasons - either by model construction (downsampling) or the choice of dataset preprocessing (character tokenization vs. sub-word tokenization).
Another dimension that NeMo can control is **batch**. Due to how we batch our samples, small and large samples all get clumped together into a single batch. So even though the individual samples are not all as long as the maximum length of T and U in that batch, when a batch of such samples is constructed, it will consume a significant amount of memory for the sake of compute efficiency.
So as is always the case - **trade-off compute speed for memory savings**.
The fused operation goes as follows :
1) Forward the entire acoustic model in a single pass. (Use global batch size here for acoustic model - found in ``model.*_ds.batch_size``)
2) Split the Acoustic Model's logits by ``fused_batch_size`` and loop over these sub-batches.
3) Construct a sub-batch of same ``fused_batch_size`` for the Prediction model. Now the target sequence length is :math:`U_{sub-batch} < U`.
4) Feed this :math:`U_{sub-batch}` into the Joint model, along with a sub-batch from the Acoustic model (with :math:`T_{sub-batch} < T)`. Remember, we only have to slice off a part of the acoustic model here since we have the full batch of samples :math:`(B, T, D)` from the acoustic model.
5) Performing steps (3) and (4) yields :math:`T_{sub-batch}` and :math:`U_{sub-batch}`. Perform sub-batch joint step - costing an intermediate :math:`(B, T_{sub-batch}, U_{sub-batch}, V)` in memory.
6) Compute loss on sub-batch and preserve in a list to be later concatenated.
7) Compute sub-batch metrics (such as Character / Word Error Rate) using the above Joint tensor and sub-batch of ground truth labels. Preserve the scores to be averaged across the entire batch later.
8) Delete the sub-batch joint matrix :math:`(B, T_{sub-batch}, U_{sub-batch}, V)`. Only gradients from .backward() are preserved now in the computation graph.
9) Repeat steps (3) - (8) until all sub-batches are consumed.
10) Cleanup step. Compute full batch WER and log. Concatenate loss list and pass to PTL to compute the equivalent of the original (full batch) Joint step. Delete ancillary objects necessary for sub-batching.
Transducer Decoding
~~~~~~~~~~~~~~~~~~~
Models which have been trained with CTC can transcribe text simply by performing a regular argmax over the output of their decoder. For transducer-based models, the three networks must operate in a synchronized manner in order to transcribe the acoustic features. The base config for the Transducer decoding step can be found in the the ``decoding`` section of `ContextNet <./models.html#ContextNet>`__ or other Transducer architectures. For further information refer to the ``Intro to Transducers`` tutorial in the ASR tutorial section.
**This config can be copy-pasted into any custom transducer model with no modification.**
The most important component at the top level is the ``strategy``. It can take one of many values:
* ``greedy``: This is sample-level greedy decoding. It is generally exceptionally slow as each sample in the batch will be decoded independently. For publications, this should be used alongside batch size of 1 for exact results.
* ``greedy_batch``: This is the general default and should nearly match the ``greedy`` decoding scores (if the acoustic features are not affected by feature mixing in batch mode). Even for small batch sizes, this strategy is significantly faster than ``greedy``.
* ``beam``: Runs beam search with the implicit language model of the Prediction model. It will generally be quite slow, and might need some tuning of the beam size to get better transcriptions.
* ``tsd``: Time synchronous decoding. Please refer to the paper: `Alignment-Length Synchronous Decoding for RNN Transducer <https://ieeexplore.ieee.org/document/9053040>`_ for details on the algorithm implemented. Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. For longer sequences, T is greater and can therefore take a long time for beams to obtain good results. TSD also requires more memory to execute.
* ``alsd``: Alignment-length synchronous decoding. Please refer to the paper: `Alignment-Length Synchronous Decoding for RNN Transducer <https://ieeexplore.ieee.org/document/9053040>`_ for details on the algorithm implemented. Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with a growth factor of T + U_max, where U_max is the maximum target length expected during execution. Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique. Therefore it is required to use larger beam sizes to achieve the same (or close to the same) decoding accuracy as TSD. For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.
* ``maes``: Modified Adaptive Expansion Search Decoding. Please refer to the paper `Accelerating RNN Transducer Inference via Adaptive Expansion Search <https://ieeexplore.ieee.org/document/9250505>`_. Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the number of expansions (for tokens) required per timestep. The number of expansions can usually be constrained to 1 or 2, and in most cases 2 is sufficient. This beam search technique can possibly obtain superior WER while sacrificing some evaluation time.
.. code-block:: yaml
decoding:
strategy: "greedy_batch"
# greedy strategy config
greedy:
max_symbols: 30
# beam strategy config
beam:
beam_size: 2
score_norm: true
softmax_temperature: 1.0 # scale the logits by some temperature prior to softmax
tsd_max_sym_exp: 10 # for Time Synchronous Decoding, int > 0
alsd_max_target_len: 5.0 # for Alignment-Length Synchronous Decoding, float > 1.0
maes_num_steps: 2 # for modified Adaptive Expansion Search, int > 0
maes_prefix_alpha: 1 # for modified Adaptive Expansion Search, int > 0
maes_expansion_beta: 2 # for modified Adaptive Expansion Search, int >= 0
maes_expansion_gamma: 2.3 # for modified Adaptive Expansion Search, float >= 0
Transducer Loss
~~~~~~~~~~~~~~~
This section configures the type of Transducer loss itself, along with possible sub-sections. By default, an optimized implementation of Transducer loss will be used which depends on Numba for CUDA acceleration. The base config for the Transducer loss section can be found in the the ``loss`` section of `ContextNet <./models.html#ContextNet>`__ or other Transducer architectures. For further information refer to the ``Intro to Transducers`` tutorial in the ASR tutorial section.
**This config can be copy-pasted into any custom transducer model with no modification.**
The loss config is based on a resolver pattern and can be used as follows:
1) ``loss_name``: ``default`` is generally a good option. Will select one of the available resolved losses and match the kwargs from a sub-configs passed via explicit ``{loss_name}_kwargs`` sub-config.
2) ``{loss_name}_kwargs``: This sub-config is passed to the resolved loss above and can be used to configure the resolved loss.
.. code-block:: yaml
loss:
loss_name: "default"
warprnnt_numba_kwargs:
fastemit_lambda: 0.0
FastEmit Regularization
^^^^^^^^^^^^^^^^^^^^^^^
FastEmit Regularization is supported for the default Numba based WarpRNNT loss. Recently proposed regularization approach - `FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization <https://arxiv.org/abs/2010.11148>`_ allows us near-direct control over the latency of transducer models.
Refer to the above paper for results and recommendations of ``fastemit_lambda``.
Fine-tuning Configurations
--------------------------
All ASR scripts support easy fine-tuning by partially/fully loading the pretrained weights from a checkpoint into the currently instantiated model. Pre-trained weights can be provided in multiple ways -
1) Providing a path to a NeMo model (via ``init_from_nemo_model``)

View file

@ -24,6 +24,7 @@ How to Use Model Export
-----------------------
The following arguments are for ``Exportable.export()``. In most cases, you should only supply the name of the output file and use all defaults:
.. code-block:: Python
def export(
self,
output: str,
@ -99,6 +100,7 @@ Those are needed for inferring in/out names and dynamic axes. If your model deri
Your model should also have an export-friendly ``forward()`` method - that can mean different things for ONNX ant TorchScript. For ONNX, you can't have forced named parameters without default, like ``forward(self, *, text)``. For TorchScript, you should avoid ``None`` and use ``Optional`` instead. The criterias are highly volatile and may change with every PyTorch version, so it's a trial-and-error process. There is also the general issue that in many cases, ``forward()`` for inference can be simplified and even use less inputs/outputs. To address this, ``Exportable`` looks for ``forward_for_export()`` method in your model and uses that instead of ``forward()`` to export:
.. code-block:: Python
# Uses forced named args, many default parameters.
def forward(
self,

View file

@ -379,6 +379,28 @@ class RNNTDecoding(AbstractRNNTDecoding):
By default, a float of 2.0 is used so that a target sequence can be at most twice
as long as the acoustic model output length T.
maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient,
and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0.
maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1
in order to reduce expensive beam search cost later. int >= 0.
maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size.
Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0,
and affects the speed of inference since large values will perform large beam search in the next step.
maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions.
The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v])
where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be
predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for
expansion apart from the "most likely" candidate.
Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed
but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value,
thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally
tuned on a validation set.
softmax_temperature: Scales the logits of the joint prior to computing log_softmax.
decoder: The Decoder/Prediction network module.
joint: The Joint network module.
vocabulary: The vocabulary (excluding the RNNT blank token) which will be used for decoding.

View file

@ -85,6 +85,28 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):
By default, a float of 2.0 is used so that a target sequence can be at most twice
as long as the acoustic model output length T.
maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient,
and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0.
maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1
in order to reduce expensive beam search cost later. int >= 0.
maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size.
Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0,
and affects the speed of inference since large values will perform large beam search in the next step.
maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions.
The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v])
where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be
predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for
expansion apart from the "most likely" candidate.
Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed
but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value,
thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally
tuned on a validation set.
softmax_temperature: Scales the logits of the joint prior to computing log_softmax.
decoder: The Decoder/Prediction network module.
joint: The Joint network module.
tokenizer: The tokenizer which will be used for decoding.