ASR patches for v1.0.0 (#2207)

* Multiple updates to RNNT; add initialization options

Signed-off-by: smajumdar <titu1994@gmail.com>

* Correct name of initialization

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update dockerignore

Signed-off-by: smajumdar <titu1994@gmail.com>

* Fix RNNT WER calculation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Address comments

Signed-off-by: smajumdar <titu1994@gmail.com>
Somshubra Majumdar 2021-05-13 23:55:42 -07:00 committed by GitHub
parent d1540e8bff
commit d1b37ed30a
8 changed files with 208 additions and 4 deletions


@@ -15,3 +15,5 @@ coverage.xml
*,cover
*.log
.git
**/*.nemo
**/*.ckpt


@@ -560,6 +560,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
predictions_lengths=encoded_len,
)
wer, _, _ = self._wer.compute()
self._wer.reset()
tensorboard_logs.update({'training_batch_wer': wer})
return {'loss': loss_value, 'log': tensorboard_logs}
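For illustration only (not part of the diff): the fix above pairs each compute() with a reset(), presumably because the WER metric accumulates state across update() calls, so without a reset the logged training_batch_wer would include all previous batches. A minimal standalone sketch of that pattern with a hypothetical toy accumulator:

class ToyWER:
    """Toy running word-error accumulator that mimics the compute/reset pattern."""

    def __init__(self):
        self.errors = 0
        self.words = 0

    def update(self, errors: int, words: int):
        # Accumulates counts across calls, like a stateful metric.
        self.errors += errors
        self.words += words

    def compute(self):
        # Returns (wer, numerator, denominator) for everything seen since the last reset.
        return self.errors / max(self.words, 1), self.errors, self.words

    def reset(self):
        self.errors = 0
        self.words = 0

wer_metric = ToyWER()
wer_metric.update(errors=2, words=10)
print(wer_metric.compute())  # (0.2, 2, 10): this batch only
wer_metric.reset()           # without this, the next batch's WER would be cumulative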
@@ -580,6 +581,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len
)
wer, wer_num, wer_denom = self._wer.compute()
self._wer.reset()
return {
'val_loss': loss_value,
'val_wer_num': wer_num,


@@ -112,6 +112,38 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
self.joint.set_loss(self.loss)
self.joint.set_wer(self.wer)
self.setup_optim_normalization()
def setup_optim_normalization(self):
"""
Helper method to set up normalization of certain parts of the model prior to the optimization step.
Supported pre-optimization normalizations are as follows:
.. code-block:: yaml
# Variational noise injection
model:
variational_noise:
std: 0.0
start_step: 0
# Joint - Length normalization
model:
normalize_joint_txu: false
# Encoder Network - gradient normalization
model:
normalize_encoder_norm: false
# Decoder / Prediction Network - gradient normalization
model:
normalize_decoder_norm: false
# Joint - gradient normalization
model:
normalize_joint_norm: false
"""
# setting up the variational noise for the decoder
if hasattr(self.cfg, 'variational_noise'):
self._optim_variational_noise_std = self.cfg['variational_noise'].get('std', 0)
@@ -120,6 +152,19 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
self._optim_variational_noise_std = 0
self._optim_variational_noise_start = 0
# Setup normalized gradients for model joint by T x U scaling factor (joint length normalization)
self._optim_normalize_joint_txu = self.cfg.get('normalize_joint_txu', False)
self._optim_normalize_txu = None
# Setup normalized encoder norm for model
self._optim_normalize_encoder_norm = self.cfg.get('normalize_encoder_norm', False)
# Setup normalized decoder norm for model
self._optim_normalize_decoder_norm = self.cfg.get('normalize_decoder_norm', False)
# Setup normalized joint norm for model
self._optim_normalize_joint_norm = self.cfg.get('normalize_joint_norm', False)
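For illustration only (not part of the diff): a sketch of how the flags read by setup_optim_normalization could be set in a model config, mirroring the YAML in the docstring above. The numeric values are arbitrary examples, not defaults taken from this commit.

from omegaconf import OmegaConf

model_cfg = OmegaConf.create(
    {
        # Variational (weight) noise added to decoder gradients from start_step onward.
        "variational_noise": {"std": 0.075, "start_step": 0},
        # Divide encoder gradients by U and decoder gradients by T.
        "normalize_joint_txu": True,
        # Divide each parameter's gradient by its own L2 norm, per sub-network.
        "normalize_encoder_norm": False,
        "normalize_decoder_norm": False,
        "normalize_joint_norm": False,
    }
)

# setup_optim_normalization() reads these with cfg.get(<key>, False).
assert model_cfg.get("normalize_joint_txu", False)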
def extract_rnnt_loss_cfg(self, cfg: Optional[DictConfig]):
"""
Helper method to extract the rnnt loss name, and potentially its kwargs
@@ -580,6 +625,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
if (sample_id + 1) % log_every_n_steps == 0:
self.wer.update(encoded, encoded_len, transcript, transcript_len)
_, scores, words = self.wer.compute()
self.wer.reset()
tensorboard_logs.update({'training_batch_wer': scores.float() / words})
else:
@@ -607,6 +653,10 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
# Log items
self.log_dict(tensorboard_logs)
# Preserve batch acoustic model T and language model U parameters if normalizing
if self._optim_normalize_joint_txu:
self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()]
return {'loss': loss_value}
def validation_step(self, batch, batch_idx, dataloader_idx=0):
@@ -635,6 +685,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
self.wer.update(encoded, encoded_len, transcript, transcript_len)
wer, wer_num, wer_denom = self.wer.compute()
self.wer.reset()
tensorboard_logs['val_wer_num'] = wer_num
tensorboard_logs['val_wer_denom'] = wer_denom
@@ -743,3 +794,32 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
dtype=param.dtype,
)
param.grad.data.add_(noise)
if self._optim_normalize_joint_txu:
T, U = self._optim_normalize_txu
if T is not None and U is not None:
for param_name, param in self.encoder.named_parameters():
if param.grad is not None:
param.grad.data.div_(U)
for param_name, param in self.decoder.named_parameters():
if param.grad is not None:
param.grad.data.div_(T)
if self._optim_normalize_encoder_norm:
for param_name, param in self.encoder.named_parameters():
if param.grad is not None:
norm = param.grad.norm()
param.grad.data.div_(norm)
if self._optim_normalize_decoder_norm:
for param_name, param in self.decoder.named_parameters():
if param.grad is not None:
norm = param.grad.norm()
param.grad.data.div_(norm)
if self._optim_normalize_joint_norm:
for param_name, param in self.joint.named_parameters():
if param.grad is not None:
norm = param.grad.norm()
param.grad.data.div_(norm)
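For illustration only (not part of the diff): a standalone sketch of the per-parameter gradient-norm normalization applied above, using a toy module in place of the encoder/decoder/joint networks. Each gradient tensor is divided by its own L2 norm, so every parameter contributes a unit-norm update direction.

import torch

toy_module = torch.nn.Linear(4, 2)
loss = toy_module(torch.randn(8, 4)).pow(2).mean()
loss.backward()

for name, param in toy_module.named_parameters():
    if param.grad is not None:
        param.grad.data.div_(param.grad.norm())  # each grad tensor now has unit L2 norm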


@@ -66,6 +66,10 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
of training.
Reference:
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
weights_init_scale: Float scale applied to the weights after initialization. Setting it below one
sometimes helps reduce variance between runs.
hidden_hidden_bias_scale: Float scale applied to the hidden-to-hidden bias. Set to 0.0 for
the default behaviour.
dropout: float, set to 0.0 by default. Optional dropout applied at the end of the final LSTM RNN layer.
vocab_size: int, specifying the vocabulary size of the embedding layer of the Prediction network,
@@ -125,6 +129,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
# Optional arguments
forget_gate_bias = prednet.get('forget_gate_bias', 1.0)
t_max = prednet.get('t_max', None)
weights_init_scale = prednet.get('weights_init_scale', 1.0)
hidden_hidden_bias_scale = prednet.get('hidden_hidden_bias_scale', 0.0)
dropout = prednet.get('dropout', 0.0)
self.random_state_sampling = random_state_sampling
@@ -135,6 +141,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
forget_gate_bias=forget_gate_bias,
t_max=t_max,
norm=normalization_mode,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
dropout=dropout,
)
@@ -245,7 +253,18 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
del y, start, state
return g, hid
def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers, forget_gate_bias, t_max, norm, dropout):
def _predict(
self,
vocab_size,
pred_n_hidden,
pred_rnn_layers,
forget_gate_bias,
t_max,
norm,
weights_init_scale,
hidden_hidden_bias_scale,
dropout,
):
"""
Prepare the trainable parameters of the Prediction Network.
@@ -256,6 +275,10 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
forget_gate_bias: Whether to perform unit forget gate bias.
t_max: Whether to perform Chrono LSTM init.
norm: Type of normalization to perform in RNN.
weights_init_scale: Float scale applied to the weights after initialization. Setting it below one
sometimes helps reduce variance between runs.
hidden_hidden_bias_scale: Float scale applied to the hidden-to-hidden bias. Set to 0.0 for
the default behaviour.
dropout: Whether to apply dropout to RNN.
"""
if self.blank_as_pad:
@@ -274,6 +297,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
forget_gate_bias=forget_gate_bias,
t_max=t_max,
dropout=dropout,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
),
}
)
@@ -764,6 +789,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint):
# Compute the wer (with logging for just 1st sub-batch)
self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens)
wer, wer_num, wer_denom = self.wer.compute()
self.wer.reset()
wer_numer_list.append(wer_num)
wer_denom_list.append(wer_denom)
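For illustration only (not part of the diff): a sketch of a prednet config exercising the two new options read above. weights_init_scale and hidden_hidden_bias_scale come from this commit; the remaining keys and all values are assumptions shown only to give the snippet context.

from omegaconf import OmegaConf

prednet_cfg = OmegaConf.create(
    {
        "pred_hidden": 320,               # assumed field, not part of this diff
        "pred_rnn_layers": 1,             # assumed field, not part of this diff
        "forget_gate_bias": 1.0,
        "t_max": None,
        "dropout": 0.1,
        "weights_init_scale": 0.5,        # shrink all LSTM weights/biases after init
        "hidden_hidden_bias_scale": 0.0,  # 0.0 keeps the previous zeroed-bias behaviour
    }
)

# Mirrors the optional-argument handling above.
weights_init_scale = prednet_cfg.get("weights_init_scale", 1.0)
hidden_hidden_bias_scale = prednet_cfg.get("hidden_hidden_bias_scale", 0.0)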


@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.init import _calculate_correct_fan
from torch.nn.modules.utils import _single
from nemo.collections.asr.parts.activations import Swish
@@ -36,6 +37,52 @@ except ImportError:
jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish}
def tds_uniform_(tensor, mode='fan_in'):
"""
Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
Normalized to -
.. math::
\text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}}
Args:
tensor: an n-dimensional `torch.Tensor`
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
"""
fan = _calculate_correct_fan(tensor, mode)
gain = 2.0 # sqrt(4.0) = 2
std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
bound = std # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def tds_normal_(tensor, mode='fan_in'):
"""
Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
Normalized to -
.. math::
\text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}}
Args:
tensor: an n-dimensional `torch.Tensor`
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
"""
fan = _calculate_correct_fan(tensor, mode)
gain = 2.0
std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
bound = std  # Use std directly as the standard deviation of the normal distribution
with torch.no_grad():
return tensor.normal_(0.0, bound)
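For illustration only (not part of the diff): a usage sketch of the TDS initializers defined above, checking the 2 / sqrt(fan_in) bound described in the docstrings. The layer shape is arbitrary.

import math
import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3)

# tds_uniform_ (defined above) samples uniformly in [-bound, bound], bound = 2 / sqrt(fan_in).
tds_uniform_(conv.weight)
fan_in = 64 * 3  # in_channels * kernel_size for a Conv1d weight
assert conv.weight.abs().max().item() <= 2.0 / math.sqrt(fan_in)

# tds_normal_ (defined above) uses the same value as the standard deviation of a normal draw.
tds_normal_(conv.weight)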
def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
if isinstance(m, MaskedConv1d):
init_weights(m.conv, mode)
@@ -49,6 +96,10 @@ def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
elif mode == 'kaiming_normal':
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
elif mode == 'tds_uniform':
tds_uniform_(m.weight)
elif mode == 'tds_normal':
tds_normal_(m.weight)
else:
raise ValueError("Unknown Initialization mode: {0}".format(mode))
elif isinstance(m, nn.BatchNorm1d):


@@ -180,7 +180,7 @@ def rnnt_loss_gpu(
acts, acts_shape = rnnt_helper.flatten_tensor(acts)
### REPRESENT THE CUDA ARRAY INTERFACE OF COSTS VECTOR ###
costs_repr = cuda.as_cuda_array(costs) # NO COPY OF DATA, JUST CHANGE REPRESENTATION
costs_repr = cuda.as_cuda_array(costs, sync=False) # NO COPY OF DATA, JUST CHANGE REPRESENTATION
wrapper = gpu_rnnt.GPURNNT(
minibatch=minibatch_size,


@@ -29,6 +29,8 @@ def rnn(
dropout: Optional[float] = 0.0,
norm_first_rnn: Optional[bool] = None,
t_max: Optional[int] = None,
weights_init_scale: float = 1.0,
hidden_hidden_bias_scale: float = 0.0,
) -> torch.nn.Module:
"""
Utility function to provide unified interface to common LSTM RNN modules.
@@ -58,6 +60,12 @@
Reference:
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
weights_init_scale: Float scale applied to the weights after initialization. Setting it below one
sometimes helps reduce variance between runs.
hidden_hidden_bias_scale: Float scale applied to the hidden-to-hidden bias. Set to 0.0 for
the default behaviour.
Returns:
A RNN module
"""
@@ -72,6 +80,8 @@
dropout=dropout,
forget_gate_bias=forget_gate_bias,
t_max=t_max,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
)
if norm == "batch":
@@ -84,6 +94,8 @@
forget_gate_bias=forget_gate_bias,
t_max=t_max,
norm_first_rnn=norm_first_rnn,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
)
if norm == "layer":
@@ -95,6 +107,8 @@
dropout=dropout,
forget_gate_bias=forget_gate_bias,
t_max=t_max,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
)
)
@@ -138,6 +152,8 @@ class LSTMDropout(torch.nn.Module):
dropout: Optional[float],
forget_gate_bias: Optional[float],
t_max: Optional[int] = None,
weights_init_scale: float = 1.0,
hidden_hidden_bias_scale: float = 0.0,
):
"""Returns an LSTM with forget gate bias init to `forget_gate_bias`.
Args:
@@ -157,6 +173,12 @@
Reference:
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
weights_init_scale: Float scale applied to the weights after initialization. Setting it below one
sometimes helps reduce variance between runs.
hidden_hidden_bias_scale: Float scale applied to the hidden-to-hidden bias. Set to 0.0 for
the default behaviour.
Returns:
A `torch.nn.LSTM`.
"""
@@ -188,10 +210,14 @@
bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)
if "bias_hh" in name:
bias = getattr(self.lstm, name)
bias.data[hidden_size : 2 * hidden_size].fill_(0)
bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale)
self.dropout = torch.nn.Dropout(dropout) if dropout else None
for name, v in self.named_parameters():
if 'weight' in name or 'bias' in name:
v.data *= float(weights_init_scale)
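For illustration only (not part of the diff): the effect of the two new knobs reproduced on a bare torch.nn.LSTM, mirroring the loops above. With hidden_hidden_bias_scale set to 0.0 the forget-gate slice of bias_hh ends up at zero, matching the previous fill_(0) behaviour, and weights_init_scale then shrinks every weight and bias.

import torch

hidden_size = 16
lstm = torch.nn.LSTM(input_size=8, hidden_size=hidden_size, num_layers=1)

weights_init_scale = 0.5
hidden_hidden_bias_scale = 0.0

for name, param in lstm.named_parameters():
    if "bias_hh" in name:
        # Scale (rather than hard-zero) the forget-gate slice of the hidden-hidden bias.
        param.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale)

for name, param in lstm.named_parameters():
    if "weight" in name or "bias" in name:
        param.data *= float(weights_init_scale)  # global post-init shrink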
def forward(
self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -214,6 +240,8 @@ class RNNLayer(torch.nn.Module):
batch_norm: bool = True,
forget_gate_bias: Optional[float] = 1.0,
t_max: Optional[int] = None,
weights_init_scale: float = 1.0,
hidden_hidden_bias_scale: float = 0.0,
):
super().__init__()
@@ -229,6 +257,8 @@
dropout=0.0,
forget_gate_bias=forget_gate_bias,
t_max=t_max,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
)
else:
self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, bias=not batch_norm)
@@ -265,6 +295,8 @@ class BNRNNSum(torch.nn.Module):
forget_gate_bias: Optional[float] = 1.0,
norm_first_rnn: bool = False,
t_max: Optional[int] = None,
weights_init_scale: float = 1.0,
hidden_hidden_bias_scale: float = 0.0,
):
super().__init__()
self.rnn_layers = rnn_layers
@@ -281,6 +313,8 @@
batch_norm=batch_norm and (norm_first_rnn or i > 0),
forget_gate_bias=forget_gate_bias,
t_max=t_max,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
)
)
@@ -367,6 +401,8 @@ def ln_lstm(
dropout: Optional[float],
forget_gate_bias: Optional[float],
t_max: Optional[int],
weights_init_scale: Optional[float] = None,
hidden_hidden_bias_scale: Optional[float] = None,
) -> torch.nn.Module:
"""Returns a ScriptModule that mimics a PyTorch native LSTM."""
# The following are not implemented.
@@ -376,6 +412,12 @@
if t_max is not None:
raise ValueError("LayerNormLSTM does not support chrono init")
if weights_init_scale is not None:
raise ValueError("LayerNormLSTM does not support `weight_init_scale`")
if hidden_hidden_bias_scale is not None:
raise ValueError("LayerNormLSTM does not support `hidden_hidden_bias_scale`")
return StackedLSTM(
num_layers,
LSTMLayer,


@@ -1177,6 +1177,7 @@ class ModelPT(LightningModule, Model):
Args:
trainer: PyTorch Lightning Trainer object.
"""
self.trainer = trainer
self._trainer = trainer
self.set_world_size(self._trainer)