ASR patches for v1.0.0 (#2207)
* Multiple updates to RNNT add initialization Signed-off-by: smajumdar <titu1994@gmail.com> * Correct name of initilization Signed-off-by: smajumdar <titu1994@gmail.com> * Update dockerignore Signed-off-by: smajumdar <titu1994@gmail.com> * Fix RNNT WER calculation Signed-off-by: smajumdar <titu1994@gmail.com> * Address comments Signed-off-by: smajumdar <titu1994@gmail.com>
This commit is contained in:
parent
d1540e8bff
commit
d1b37ed30a
|
@ -15,3 +15,5 @@ coverage.xml
|
|||
*,cover
|
||||
*.log
|
||||
.git
|
||||
**/*.nemo
|
||||
**/*.ckpt
|
||||
|
|
|
@ -560,6 +560,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
|
|||
predictions_lengths=encoded_len,
|
||||
)
|
||||
wer, _, _ = self._wer.compute()
|
||||
self._wer.reset()
|
||||
tensorboard_logs.update({'training_batch_wer': wer})
|
||||
|
||||
return {'loss': loss_value, 'log': tensorboard_logs}
|
||||
|
@ -580,6 +581,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
|
|||
predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len
|
||||
)
|
||||
wer, wer_num, wer_denom = self._wer.compute()
|
||||
self._wer.reset()
|
||||
return {
|
||||
'val_loss': loss_value,
|
||||
'val_wer_num': wer_num,
|
||||
|
|
|
@ -112,6 +112,38 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
self.joint.set_loss(self.loss)
|
||||
self.joint.set_wer(self.wer)
|
||||
|
||||
self.setup_optim_normalization()
|
||||
|
||||
def setup_optim_normalization(self):
|
||||
"""
|
||||
Helper method to setup normalization of certain parts of the model prior to the optimization step.
|
||||
|
||||
Supported pre-optimization normalizations are as follows:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
# Variation Noise injection
|
||||
model:
|
||||
variational_noise:
|
||||
std: 0.0
|
||||
start_step: 0
|
||||
|
||||
# Joint - Length normalization
|
||||
model:
|
||||
normalize_joint_txu: false
|
||||
|
||||
# Encoder Network - gradient normalization
|
||||
model:
|
||||
normalize_encoder_norm: false
|
||||
|
||||
# Decoder / Prediction Network - gradient normalization
|
||||
model:
|
||||
normalize_decoder_norm: false
|
||||
|
||||
# Joint - gradient normalization
|
||||
model:
|
||||
normalize_joint_norm: false
|
||||
"""
|
||||
# setting up the variational noise for the decoder
|
||||
if hasattr(self.cfg, 'variational_noise'):
|
||||
self._optim_variational_noise_std = self.cfg['variational_noise'].get('std', 0)
|
||||
|
@ -120,6 +152,19 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
self._optim_variational_noise_std = 0
|
||||
self._optim_variational_noise_start = 0
|
||||
|
||||
# Setup normalized gradients for model joint by T x U scaling factor (joint length normalization)
|
||||
self._optim_normalize_joint_txu = self.cfg.get('normalize_joint_txu', False)
|
||||
self._optim_normalize_txu = None
|
||||
|
||||
# Setup normalized encoder norm for model
|
||||
self._optim_normalize_encoder_norm = self.cfg.get('normalize_encoder_norm', False)
|
||||
|
||||
# Setup normalized decoder norm for model
|
||||
self._optim_normalize_decoder_norm = self.cfg.get('normalize_decoder_norm', False)
|
||||
|
||||
# Setup normalized joint norm for model
|
||||
self._optim_normalize_joint_norm = self.cfg.get('normalize_joint_norm', False)
|
||||
|
||||
def extract_rnnt_loss_cfg(self, cfg: Optional[DictConfig]):
|
||||
"""
|
||||
Helper method to extract the rnnt loss name, and potentially its kwargs
|
||||
|
@ -580,6 +625,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
if (sample_id + 1) % log_every_n_steps == 0:
|
||||
self.wer.update(encoded, encoded_len, transcript, transcript_len)
|
||||
_, scores, words = self.wer.compute()
|
||||
self.wer.reset()
|
||||
tensorboard_logs.update({'training_batch_wer': scores.float() / words})
|
||||
|
||||
else:
|
||||
|
@ -607,6 +653,10 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
# Log items
|
||||
self.log_dict(tensorboard_logs)
|
||||
|
||||
# Preserve batch acoustic model T and language model U parameters if normalizing
|
||||
if self._optim_normalize_joint_txu:
|
||||
self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()]
|
||||
|
||||
return {'loss': loss_value}
|
||||
|
||||
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
|
@ -635,6 +685,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
|
||||
self.wer.update(encoded, encoded_len, transcript, transcript_len)
|
||||
wer, wer_num, wer_denom = self.wer.compute()
|
||||
self.wer.reset()
|
||||
|
||||
tensorboard_logs['val_wer_num'] = wer_num
|
||||
tensorboard_logs['val_wer_denom'] = wer_denom
|
||||
|
@ -743,3 +794,32 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
dtype=param.dtype,
|
||||
)
|
||||
param.grad.data.add_(noise)
|
||||
|
||||
if self._optim_normalize_joint_txu:
|
||||
T, U = self._optim_normalize_txu
|
||||
if T is not None and U is not None:
|
||||
for param_name, param in self.encoder.named_parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data.div_(U)
|
||||
|
||||
for param_name, param in self.decoder.named_parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data.div_(T)
|
||||
|
||||
if self._optim_normalize_encoder_norm:
|
||||
for param_name, param in self.encoder.named_parameters():
|
||||
if param.grad is not None:
|
||||
norm = param.grad.norm()
|
||||
param.grad.data.div_(norm)
|
||||
|
||||
if self._optim_normalize_decoder_norm:
|
||||
for param_name, param in self.decoder.named_parameters():
|
||||
if param.grad is not None:
|
||||
norm = param.grad.norm()
|
||||
param.grad.data.div_(norm)
|
||||
|
||||
if self._optim_normalize_joint_norm:
|
||||
for param_name, param in self.joint.named_parameters():
|
||||
if param.grad is not None:
|
||||
norm = param.grad.norm()
|
||||
param.grad.data.div_(norm)
|
||||
|
|
|
@ -66,6 +66,10 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
|
|||
of training.
|
||||
Reference:
|
||||
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
|
||||
weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
|
||||
sometimes helps reduce variance between runs.
|
||||
hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
|
||||
the default behaviour.
|
||||
dropout: float, set to 0.0 by default. Optional dropout applied at the end of the final LSTM RNN layer.
|
||||
|
||||
vocab_size: int, specifying the vocabulary size of the embedding layer of the Prediction network,
|
||||
|
@ -125,6 +129,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
|
|||
# Optional arguments
|
||||
forget_gate_bias = prednet.get('forget_gate_bias', 1.0)
|
||||
t_max = prednet.get('t_max', None)
|
||||
weights_init_scale = prednet.get('weights_init_scale', 1.0)
|
||||
hidden_hidden_bias_scale = prednet.get('hidden_hidden_bias_scale', 0.0)
|
||||
dropout = prednet.get('dropout', 0.0)
|
||||
self.random_state_sampling = random_state_sampling
|
||||
|
||||
|
@ -135,6 +141,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
|
|||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
norm=normalization_mode,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
@ -245,7 +253,18 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
|
|||
del y, start, state
|
||||
return g, hid
|
||||
|
||||
def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers, forget_gate_bias, t_max, norm, dropout):
|
||||
def _predict(
|
||||
self,
|
||||
vocab_size,
|
||||
pred_n_hidden,
|
||||
pred_rnn_layers,
|
||||
forget_gate_bias,
|
||||
t_max,
|
||||
norm,
|
||||
weights_init_scale,
|
||||
hidden_hidden_bias_scale,
|
||||
dropout,
|
||||
):
|
||||
"""
|
||||
Prepare the trainable parameters of the Prediction Network.
|
||||
|
||||
|
@ -256,6 +275,10 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
|
|||
forget_gate_bias: Whether to perform unit forget gate bias.
|
||||
t_max: Whether to perform Chrono LSTM init.
|
||||
norm: Type of normalization to perform in RNN.
|
||||
weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
|
||||
sometimes helps reduce variance between runs.
|
||||
hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
|
||||
the default behaviour.
|
||||
dropout: Whether to apply dropout to RNN.
|
||||
"""
|
||||
if self.blank_as_pad:
|
||||
|
@ -274,6 +297,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
|
|||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
dropout=dropout,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
@ -764,6 +789,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint):
|
|||
# Compute the wer (with logging for just 1st sub-batch)
|
||||
self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens)
|
||||
wer, wer_num, wer_denom = self.wer.compute()
|
||||
self.wer.reset()
|
||||
|
||||
wer_numer_list.append(wer_num)
|
||||
wer_denom_list.append(wer_denom)
|
||||
|
|
|
@ -12,12 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.init import _calculate_correct_fan
|
||||
from torch.nn.modules.utils import _single
|
||||
|
||||
from nemo.collections.asr.parts.activations import Swish
|
||||
|
@ -36,6 +37,52 @@ except ImportError:
|
|||
jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish}
|
||||
|
||||
|
||||
def tds_uniform_(tensor, mode='fan_in'):
|
||||
"""
|
||||
Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
|
||||
Normalized to -
|
||||
|
||||
.. math::
|
||||
\text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}}
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
"""
|
||||
fan = _calculate_correct_fan(tensor, mode)
|
||||
gain = 2.0 # sqrt(4.0) = 2
|
||||
std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
|
||||
bound = std # Calculate uniform bounds from standard deviation
|
||||
with torch.no_grad():
|
||||
return tensor.uniform_(-bound, bound)
|
||||
|
||||
|
||||
def tds_normal_(tensor, mode='fan_in'):
|
||||
"""
|
||||
Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
|
||||
Normalized to -
|
||||
|
||||
.. math::
|
||||
\text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}}
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
"""
|
||||
fan = _calculate_correct_fan(tensor, mode)
|
||||
gain = 2.0
|
||||
std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
|
||||
bound = std # Calculate uniform bounds from standard deviation
|
||||
with torch.no_grad():
|
||||
return tensor.normal_(0.0, bound)
|
||||
|
||||
|
||||
def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
|
||||
if isinstance(m, MaskedConv1d):
|
||||
init_weights(m.conv, mode)
|
||||
|
@ -49,6 +96,10 @@ def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
|
|||
nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
|
||||
elif mode == 'kaiming_normal':
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
elif mode == 'tds_uniform':
|
||||
tds_uniform_(m.weight)
|
||||
elif mode == 'tds_normal':
|
||||
tds_normal_(m.weight)
|
||||
else:
|
||||
raise ValueError("Unknown Initialization mode: {0}".format(mode))
|
||||
elif isinstance(m, nn.BatchNorm1d):
|
||||
|
|
|
@ -180,7 +180,7 @@ def rnnt_loss_gpu(
|
|||
acts, acts_shape = rnnt_helper.flatten_tensor(acts)
|
||||
|
||||
### REPRESENT THE CUDA ARRAY INTERFACE OF COSTS VECTOR ###
|
||||
costs_repr = cuda.as_cuda_array(costs) # NO COPY OF DATA, JUST CHANGE REPRESENTATION
|
||||
costs_repr = cuda.as_cuda_array(costs, sync=False) # NO COPY OF DATA, JUST CHANGE REPRESENTATION
|
||||
|
||||
wrapper = gpu_rnnt.GPURNNT(
|
||||
minibatch=minibatch_size,
|
||||
|
|
|
@ -29,6 +29,8 @@ def rnn(
|
|||
dropout: Optional[float] = 0.0,
|
||||
norm_first_rnn: Optional[bool] = None,
|
||||
t_max: Optional[int] = None,
|
||||
weights_init_scale: float = 1.0,
|
||||
hidden_hidden_bias_scale: float = 0.0,
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
Utility function to provide unified interface to common LSTM RNN modules.
|
||||
|
@ -58,6 +60,12 @@ def rnn(
|
|||
Reference:
|
||||
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
|
||||
|
||||
weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
|
||||
sometimes helps reduce variance between runs.
|
||||
|
||||
hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
|
||||
the default behaviour.
|
||||
|
||||
Returns:
|
||||
A RNN module
|
||||
"""
|
||||
|
@ -72,6 +80,8 @@ def rnn(
|
|||
dropout=dropout,
|
||||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
)
|
||||
|
||||
if norm == "batch":
|
||||
|
@ -84,6 +94,8 @@ def rnn(
|
|||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
norm_first_rnn=norm_first_rnn,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
)
|
||||
|
||||
if norm == "layer":
|
||||
|
@ -95,6 +107,8 @@ def rnn(
|
|||
dropout=dropout,
|
||||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -138,6 +152,8 @@ class LSTMDropout(torch.nn.Module):
|
|||
dropout: Optional[float],
|
||||
forget_gate_bias: Optional[float],
|
||||
t_max: Optional[int] = None,
|
||||
weights_init_scale: float = 1.0,
|
||||
hidden_hidden_bias_scale: float = 0.0,
|
||||
):
|
||||
"""Returns an LSTM with forget gate bias init to `forget_gate_bias`.
|
||||
Args:
|
||||
|
@ -157,6 +173,12 @@ class LSTMDropout(torch.nn.Module):
|
|||
Reference:
|
||||
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
|
||||
|
||||
weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
|
||||
sometimes helps reduce variance between runs.
|
||||
|
||||
hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
|
||||
the default behaviour.
|
||||
|
||||
Returns:
|
||||
A `torch.nn.LSTM`.
|
||||
"""
|
||||
|
@ -188,10 +210,14 @@ class LSTMDropout(torch.nn.Module):
|
|||
bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)
|
||||
if "bias_hh" in name:
|
||||
bias = getattr(self.lstm, name)
|
||||
bias.data[hidden_size : 2 * hidden_size].fill_(0)
|
||||
bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout) if dropout else None
|
||||
|
||||
for name, v in self.named_parameters():
|
||||
if 'weight' in name or 'bias' in name:
|
||||
v.data *= float(weights_init_scale)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
@ -214,6 +240,8 @@ class RNNLayer(torch.nn.Module):
|
|||
batch_norm: bool = True,
|
||||
forget_gate_bias: Optional[float] = 1.0,
|
||||
t_max: Optional[int] = None,
|
||||
weights_init_scale: float = 1.0,
|
||||
hidden_hidden_bias_scale: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -229,6 +257,8 @@ class RNNLayer(torch.nn.Module):
|
|||
dropout=0.0,
|
||||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
)
|
||||
else:
|
||||
self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, bias=not batch_norm)
|
||||
|
@ -265,6 +295,8 @@ class BNRNNSum(torch.nn.Module):
|
|||
forget_gate_bias: Optional[float] = 1.0,
|
||||
norm_first_rnn: bool = False,
|
||||
t_max: Optional[int] = None,
|
||||
weights_init_scale: float = 1.0,
|
||||
hidden_hidden_bias_scale: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.rnn_layers = rnn_layers
|
||||
|
@ -281,6 +313,8 @@ class BNRNNSum(torch.nn.Module):
|
|||
batch_norm=batch_norm and (norm_first_rnn or i > 0),
|
||||
forget_gate_bias=forget_gate_bias,
|
||||
t_max=t_max,
|
||||
weights_init_scale=weights_init_scale,
|
||||
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -367,6 +401,8 @@ def ln_lstm(
|
|||
dropout: Optional[float],
|
||||
forget_gate_bias: Optional[float],
|
||||
t_max: Optional[int],
|
||||
weights_init_scale: Optional[float] = None,
|
||||
hidden_hidden_bias_scale: Optional[float] = None,
|
||||
) -> torch.nn.Module:
|
||||
"""Returns a ScriptModule that mimics a PyTorch native LSTM."""
|
||||
# The following are not implemented.
|
||||
|
@ -376,6 +412,12 @@ def ln_lstm(
|
|||
if t_max is not None:
|
||||
raise ValueError("LayerNormLSTM does not support chrono init")
|
||||
|
||||
if weights_init_scale is not None:
|
||||
raise ValueError("LayerNormLSTM does not support `weight_init_scale`")
|
||||
|
||||
if hidden_hidden_bias_scale is not None:
|
||||
raise ValueError("LayerNormLSTM does not support `hidden_hidden_bias_scale`")
|
||||
|
||||
return StackedLSTM(
|
||||
num_layers,
|
||||
LSTMLayer,
|
||||
|
|
|
@ -1177,6 +1177,7 @@ class ModelPT(LightningModule, Model):
|
|||
Args:
|
||||
trainer: PyTorch Lightning Trainer object.
|
||||
"""
|
||||
self.trainer = trainer
|
||||
self._trainer = trainer
|
||||
self.set_world_size(self._trainer)
|
||||
|
||||
|
|
Loading…
Reference in a new issue