diff --git a/.dockerignore b/.dockerignore
index fdc976d48..14f5114d0 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -15,3 +15,5 @@ coverage.xml
 *,cover
 *.log
 .git
+**/*.nemo
+**/*.ckpt
diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py
index 216b1f571..0004be239 100644
--- a/nemo/collections/asr/models/ctc_models.py
+++ b/nemo/collections/asr/models/ctc_models.py
@@ -560,6 +560,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
                 predictions_lengths=encoded_len,
             )
             wer, _, _ = self._wer.compute()
+            self._wer.reset()
             tensorboard_logs.update({'training_batch_wer': wer})

         return {'loss': loss_value, 'log': tensorboard_logs}
@@ -580,6 +581,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
             predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len
         )
         wer, wer_num, wer_denom = self._wer.compute()
+        self._wer.reset()
         return {
             'val_loss': loss_value,
             'val_wer_num': wer_num,
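Each `reset()` added after `compute()` above clears the metric's accumulated counts, so `training_batch_wer` and the per-batch validation numbers reflect only the current batch instead of a running aggregate. The toy class below is a minimal sketch of that update/compute/reset cycle; it is purely illustrative and is not NeMo's `WER` implementation.

```python
# Hypothetical toy metric illustrating why reset() after compute() matters:
# without it, the error/word counts keep accumulating across logging steps.
class ToyWER:
    def __init__(self):
        self.errors = 0
        self.words = 0

    def update(self, errors: int, words: int):
        self.errors += errors
        self.words += words

    def compute(self):
        return self.errors / max(self.words, 1), self.errors, self.words

    def reset(self):
        self.errors = 0
        self.words = 0


metric = ToyWER()
metric.update(errors=2, words=10)
print(metric.compute())  # (0.2, 2, 10): true WER of this batch
metric.reset()           # without this, the next compute() would mix batches
metric.update(errors=0, words=10)
print(metric.compute())  # (0.0, 0, 10); it would be (0.1, 2, 20) had reset() been skipped
```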
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index d1f05f74a..e535faa3f 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -112,6 +112,38 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
         self.joint.set_loss(self.loss)
         self.joint.set_wer(self.wer)

+        self.setup_optim_normalization()
+
+    def setup_optim_normalization(self):
+        """
+        Helper method to set up normalization of certain parts of the model prior to the optimization step.
+
+        Supported pre-optimization normalizations are as follows:
+
+        .. code-block:: yaml
+
+            # Variational noise injection
+            model:
+                variational_noise:
+                    std: 0.0
+                    start_step: 0
+
+            # Joint - Length normalization
+            model:
+                normalize_joint_txu: false
+
+            # Encoder Network - gradient normalization
+            model:
+                normalize_encoder_norm: false
+
+            # Decoder / Prediction Network - gradient normalization
+            model:
+                normalize_decoder_norm: false
+
+            # Joint - gradient normalization
+            model:
+                normalize_joint_norm: false
+        """
         # setting up the variational noise for the decoder
         if hasattr(self.cfg, 'variational_noise'):
             self._optim_variational_noise_std = self.cfg['variational_noise'].get('std', 0)
@@ -120,6 +152,19 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
             self._optim_variational_noise_std = 0
             self._optim_variational_noise_start = 0

+        # Setup normalized gradients for model joint by T x U scaling factor (joint length normalization)
+        self._optim_normalize_joint_txu = self.cfg.get('normalize_joint_txu', False)
+        self._optim_normalize_txu = None
+
+        # Setup normalized encoder norm for model
+        self._optim_normalize_encoder_norm = self.cfg.get('normalize_encoder_norm', False)
+
+        # Setup normalized decoder norm for model
+        self._optim_normalize_decoder_norm = self.cfg.get('normalize_decoder_norm', False)
+
+        # Setup normalized joint norm for model
+        self._optim_normalize_joint_norm = self.cfg.get('normalize_joint_norm', False)
+
     def extract_rnnt_loss_cfg(self, cfg: Optional[DictConfig]):
         """
         Helper method to extract the rnnt loss name, and potentially its kwargs
@@ -580,6 +625,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
             if (sample_id + 1) % log_every_n_steps == 0:
                 self.wer.update(encoded, encoded_len, transcript, transcript_len)
                 _, scores, words = self.wer.compute()
+                self.wer.reset()
                 tensorboard_logs.update({'training_batch_wer': scores.float() / words})

             else:
@@ -607,6 +653,10 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
         # Log items
         self.log_dict(tensorboard_logs)

+        # Preserve batch acoustic model T and language model U parameters if normalizing
+        if self._optim_normalize_joint_txu:
+            self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()]
+
         return {'loss': loss_value}

     def validation_step(self, batch, batch_idx, dataloader_idx=0):
@@ -635,6 +685,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):

             self.wer.update(encoded, encoded_len, transcript, transcript_len)
             wer, wer_num, wer_denom = self.wer.compute()
+            self.wer.reset()

             tensorboard_logs['val_wer_num'] = wer_num
             tensorboard_logs['val_wer_denom'] = wer_denom
@@ -743,3 +794,32 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
                         dtype=param.dtype,
                     )
                     param.grad.data.add_(noise)
+
+        if self._optim_normalize_joint_txu:
+            T, U = self._optim_normalize_txu
+            if T is not None and U is not None:
+                for param_name, param in self.encoder.named_parameters():
+                    if param.grad is not None:
+                        param.grad.data.div_(U)
+
+                for param_name, param in self.decoder.named_parameters():
+                    if param.grad is not None:
+                        param.grad.data.div_(T)
+
+        if self._optim_normalize_encoder_norm:
+            for param_name, param in self.encoder.named_parameters():
+                if param.grad is not None:
+                    norm = param.grad.norm()
+                    param.grad.data.div_(norm)
+
+        if self._optim_normalize_decoder_norm:
+            for param_name, param in self.decoder.named_parameters():
+                if param.grad is not None:
+                    norm = param.grad.norm()
+                    param.grad.data.div_(norm)
+
+        if self._optim_normalize_joint_norm:
+            for param_name, param in self.joint.named_parameters():
+                if param.grad is not None:
+                    norm = param.grad.norm()
+                    param.grad.data.div_(norm)
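The flags read in `setup_optim_normalization()` come straight from the model config. Below is a hedged sketch of how they might be supplied; the key names are taken from the diff, while the OmegaConf wrapper and the example values are illustrative only, not recommendations.

```python
# Sketch: supplying the new pre-optimization normalization options via the
# model config that setup_optim_normalization() reads. Key names come from
# the diff; the wrapper and example values are illustrative only.
from omegaconf import OmegaConf

model_overrides = OmegaConf.create(
    {
        'variational_noise': {'std': 0.05, 'start_step': 1000},  # decoder gradient noise (example values)
        'normalize_joint_txu': True,      # divide encoder grads by U and decoder grads by T
        'normalize_encoder_norm': False,  # divide each encoder grad by its own norm
        'normalize_decoder_norm': False,  # same for the prediction network
        'normalize_joint_norm': False,    # same for the joint network
    }
)
# These keys live at the top level of the model config (self.cfg), so they
# would be merged into the model section before EncDecRNNTModel is built.
```

Dividing encoder gradients by U and decoder gradients by T mirrors the `on_after_backward` logic above; presumably the intent is to keep gradient magnitudes comparable across batches with very different acoustic length T and target length U.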
diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py
index 881ff75d5..55e224fcf 100644
--- a/nemo/collections/asr/modules/rnnt.py
+++ b/nemo/collections/asr/modules/rnnt.py
@@ -66,6 +66,10 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
             of training.
             Reference:
             [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
+        weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
+            sometimes helps reduce variance between runs.
+        hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
+            the default behaviour.
         dropout: float, set to 0.0 by default. Optional dropout applied at the end of the final LSTM RNN layer.

         vocab_size: int, specifying the vocabulary size of the embedding layer of the Prediction network,
@@ -125,6 +129,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
         # Optional arguments
         forget_gate_bias = prednet.get('forget_gate_bias', 1.0)
         t_max = prednet.get('t_max', None)
+        weights_init_scale = prednet.get('weights_init_scale', 1.0)
+        hidden_hidden_bias_scale = prednet.get('hidden_hidden_bias_scale', 0.0)
         dropout = prednet.get('dropout', 0.0)

         self.random_state_sampling = random_state_sampling
@@ -135,6 +141,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
             forget_gate_bias=forget_gate_bias,
             t_max=t_max,
             norm=normalization_mode,
+            weights_init_scale=weights_init_scale,
+            hidden_hidden_bias_scale=hidden_hidden_bias_scale,
             dropout=dropout,
         )

@@ -245,7 +253,18 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
         del y, start, state
         return g, hid

-    def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers, forget_gate_bias, t_max, norm, dropout):
+    def _predict(
+        self,
+        vocab_size,
+        pred_n_hidden,
+        pred_rnn_layers,
+        forget_gate_bias,
+        t_max,
+        norm,
+        weights_init_scale,
+        hidden_hidden_bias_scale,
+        dropout,
+    ):
         """
         Prepare the trainable parameters of the Prediction Network.

@@ -256,6 +275,10 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
             forget_gate_bias: Whether to perform unit forget gate bias.
             t_max: Whether to perform Chrono LSTM init.
             norm: Type of normalization to perform in RNN.
+            weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
+                sometimes helps reduce variance between runs.
+            hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
+                the default behaviour.
             dropout: Whether to apply dropout to RNN.
         """
         if self.blank_as_pad:
@@ -274,6 +297,8 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder):
                     forget_gate_bias=forget_gate_bias,
                     t_max=t_max,
                     dropout=dropout,
+                    weights_init_scale=weights_init_scale,
+                    hidden_hidden_bias_scale=hidden_hidden_bias_scale,
                 ),
             }
         )
@@ -764,6 +789,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint):
                 # Compute the wer (with logging for just 1st sub-batch)
                 self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens)
                 wer, wer_num, wer_denom = self.wer.compute()
+                self.wer.reset()
                 wer_numer_list.append(wer_num)
                 wer_denom_list.append(wer_denom)

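The prediction network now reads two extra keys from its `prednet` config via `.get()`, so existing configs keep working unchanged. A hedged sketch of a prednet section with the new options follows; the two new key names and their defaults come from the diff, while the other keys and values are illustrative only.

```python
# Sketch of a prednet config dict with the new options. 'weights_init_scale'
# and 'hidden_hidden_bias_scale' (and their defaults) come from the diff;
# the remaining keys/values are illustrative.
prednet_cfg = {
    'pred_hidden': 640,               # illustrative hidden size
    'pred_rnn_layers': 1,             # illustrative depth
    'forget_gate_bias': 1.0,          # existing option
    't_max': None,                    # existing option
    'dropout': 0.1,                   # illustrative
    'weights_init_scale': 0.5,        # scale every LSTM weight/bias after init
    'hidden_hidden_bias_scale': 0.0,  # 0.0 reproduces the previous hard-coded behaviour
}

# RNNTDecoder reads the new keys defensively, so omitting them is equivalent
# to the old behaviour (scale 1.0, bias scale 0.0):
weights_init_scale = prednet_cfg.get('weights_init_scale', 1.0)
hidden_hidden_bias_scale = prednet_cfg.get('hidden_hidden_bias_scale', 0.0)
print(weights_init_scale, hidden_hidden_bias_scale)  # 0.5 0.0
```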
diff --git a/nemo/collections/asr/parts/jasper.py b/nemo/collections/asr/parts/jasper.py
index 260fb5e50..73cc213dd 100644
--- a/nemo/collections/asr/parts/jasper.py
+++ b/nemo/collections/asr/parts/jasper.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import math
 from typing import Callable, List, Optional, Tuple

 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.nn.init import _calculate_correct_fan
 from torch.nn.modules.utils import _single

 from nemo.collections.asr.parts.activations import Swish
@@ -36,6 +37,52 @@ except ImportError:

 jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish}

+
+def tds_uniform_(tensor, mode='fan_in'):
+    """
+    Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
+    Normalized to -
+
+    .. math::
+        \text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}}
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+            preserves the magnitude of the variance of the weights in the
+            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+            backwards pass.
+    """
+    fan = _calculate_correct_fan(tensor, mode)
+    gain = 2.0  # sqrt(4.0) = 2
+    std = gain / math.sqrt(fan)  # sqrt(4.0 / fan_in)
+    bound = std  # Calculate uniform bounds from standard deviation
+    with torch.no_grad():
+        return tensor.uniform_(-bound, bound)
+
+
+def tds_normal_(tensor, mode='fan_in'):
+    """
+    Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
+    Normalized to -
+
+    .. math::
+        \text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}}
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+            preserves the magnitude of the variance of the weights in the
+            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+            backwards pass.
+    """
+    fan = _calculate_correct_fan(tensor, mode)
+    gain = 2.0
+    std = gain / math.sqrt(fan)  # sqrt(4.0 / fan_in)
+    bound = std  # Used as the standard deviation of the normal distribution
+    with torch.no_grad():
+        return tensor.normal_(0.0, bound)
+
+
 def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
     if isinstance(m, MaskedConv1d):
         init_weights(m.conv, mode)
@@ -49,6 +96,10 @@ def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
             nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
         elif mode == 'kaiming_normal':
             nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+        elif mode == 'tds_uniform':
+            tds_uniform_(m.weight)
+        elif mode == 'tds_normal':
+            tds_normal_(m.weight)
         else:
             raise ValueError("Unknown Initialization mode: {0}".format(mode))
     elif isinstance(m, nn.BatchNorm1d):
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
index 038713f30..5caf2fea2 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -180,7 +180,7 @@ def rnnt_loss_gpu(
     acts, acts_shape = rnnt_helper.flatten_tensor(acts)

     ### REPRESENT THE CUDA ARRAY INTERFACE OF COSTS VECTOR ###
-    costs_repr = cuda.as_cuda_array(costs)  # NO COPY OF DATA, JUST CHANGE REPRESENTATION
+    costs_repr = cuda.as_cuda_array(costs, sync=False)  # NO COPY OF DATA, JUST CHANGE REPRESENTATION

     wrapper = gpu_rnnt.GPURNNT(
         minibatch=minibatch_size,
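Both new TDS initializers draw from a distribution whose scale is `2 / sqrt(fan)`. The snippet below is a small numeric sanity check of that bound; it re-implements the uniform variant inline (mirroring the diff) rather than importing the private `_calculate_correct_fan` helper, and the Conv1d shape is arbitrary. Selecting the new modes presumably goes through the same init-mode argument that already routes into `init_weights`.

```python
# Sanity check of the tds_uniform_ bound from the diff: bound = 2 / sqrt(fan_in),
# where fan_in = in_channels * kernel_size for a Conv1d weight. The initializer
# body is re-implemented inline here for illustration; shapes are arbitrary.
import math
import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=11)
fan_in = 64 * 11  # in_channels * receptive field size for the Conv1d weight
bound = 2.0 / math.sqrt(fan_in)

with torch.no_grad():
    conv.weight.uniform_(-bound, bound)  # same call tds_uniform_ performs

assert conv.weight.abs().max() <= bound
print(f"weights drawn from U(-{bound:.4f}, {bound:.4f})")
```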
diff --git a/nemo/collections/common/parts/rnn.py b/nemo/collections/common/parts/rnn.py
index 0945c2926..ff2618cac 100644
--- a/nemo/collections/common/parts/rnn.py
+++ b/nemo/collections/common/parts/rnn.py
@@ -29,6 +29,8 @@ def rnn(
     dropout: Optional[float] = 0.0,
     norm_first_rnn: Optional[bool] = None,
     t_max: Optional[int] = None,
+    weights_init_scale: float = 1.0,
+    hidden_hidden_bias_scale: float = 0.0,
 ) -> torch.nn.Module:
     """
     Utility function to provide unified interface to common LSTM RNN modules.
@@ -58,6 +60,12 @@ def rnn(

             Reference:
             [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
+        weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
+            sometimes helps reduce variance between runs.
+
+        hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
+            the default behaviour.
+
     Returns:
         A RNN module
     """
@@ -72,6 +80,8 @@ def rnn(
             dropout=dropout,
             forget_gate_bias=forget_gate_bias,
             t_max=t_max,
+            weights_init_scale=weights_init_scale,
+            hidden_hidden_bias_scale=hidden_hidden_bias_scale,
         )

     if norm == "batch":
@@ -84,6 +94,8 @@ def rnn(
             forget_gate_bias=forget_gate_bias,
             t_max=t_max,
             norm_first_rnn=norm_first_rnn,
+            weights_init_scale=weights_init_scale,
+            hidden_hidden_bias_scale=hidden_hidden_bias_scale,
         )

     if norm == "layer":
@@ -95,6 +107,8 @@ def rnn(
                 dropout=dropout,
                 forget_gate_bias=forget_gate_bias,
                 t_max=t_max,
+                weights_init_scale=weights_init_scale,
+                hidden_hidden_bias_scale=hidden_hidden_bias_scale,
             )
         )

@@ -138,6 +152,8 @@ class LSTMDropout(torch.nn.Module):
         dropout: Optional[float],
         forget_gate_bias: Optional[float],
         t_max: Optional[int] = None,
+        weights_init_scale: float = 1.0,
+        hidden_hidden_bias_scale: float = 0.0,
     ):
         """Returns an LSTM with forget gate bias init to `forget_gate_bias`.
         Args:
@@ -157,6 +173,12 @@ class LSTMDropout(torch.nn.Module):
                 Reference:
                 [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)

+            weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
+                sometimes helps reduce variance between runs.
+
+            hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
+                the default behaviour.
+
         Returns:
             A `torch.nn.LSTM`.
         """
@@ -188,10 +210,14 @@ class LSTMDropout(torch.nn.Module):
                 bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)
             if "bias_hh" in name:
                 bias = getattr(self.lstm, name)
-                bias.data[hidden_size : 2 * hidden_size].fill_(0)
+                bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale)

         self.dropout = torch.nn.Dropout(dropout) if dropout else None

+        for name, v in self.named_parameters():
+            if 'weight' in name or 'bias' in name:
+                v.data *= float(weights_init_scale)
+
     def forward(
         self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -214,6 +240,8 @@ class RNNLayer(torch.nn.Module):
         batch_norm: bool = True,
         forget_gate_bias: Optional[float] = 1.0,
         t_max: Optional[int] = None,
+        weights_init_scale: float = 1.0,
+        hidden_hidden_bias_scale: float = 0.0,
     ):
         super().__init__()

@@ -229,6 +257,8 @@ class RNNLayer(torch.nn.Module):
                 dropout=0.0,
                 forget_gate_bias=forget_gate_bias,
                 t_max=t_max,
+                weights_init_scale=weights_init_scale,
+                hidden_hidden_bias_scale=hidden_hidden_bias_scale,
             )
         else:
             self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, bias=not batch_norm)
@@ -265,6 +295,8 @@ class BNRNNSum(torch.nn.Module):
         forget_gate_bias: Optional[float] = 1.0,
         norm_first_rnn: bool = False,
         t_max: Optional[int] = None,
+        weights_init_scale: float = 1.0,
+        hidden_hidden_bias_scale: float = 0.0,
     ):
         super().__init__()
         self.rnn_layers = rnn_layers
@@ -281,6 +313,8 @@ class BNRNNSum(torch.nn.Module):
                     batch_norm=batch_norm and (norm_first_rnn or i > 0),
                     forget_gate_bias=forget_gate_bias,
                     t_max=t_max,
+                    weights_init_scale=weights_init_scale,
+                    hidden_hidden_bias_scale=hidden_hidden_bias_scale,
                 )
             )

@@ -367,6 +401,8 @@ def ln_lstm(
     dropout: Optional[float],
     forget_gate_bias: Optional[float],
     t_max: Optional[int],
+    weights_init_scale: Optional[float] = None,
+    hidden_hidden_bias_scale: Optional[float] = None,
 ) -> torch.nn.Module:
     """Returns a ScriptModule that mimics a PyTorch native LSTM."""
     # The following are not implemented.
@@ -376,6 +412,12 @@ def ln_lstm(
     if t_max is not None:
         raise ValueError("LayerNormLSTM does not support chrono init")

+    if weights_init_scale is not None:
+        raise ValueError("LayerNormLSTM does not support `weights_init_scale`")
+
+    if hidden_hidden_bias_scale is not None:
+        raise ValueError("LayerNormLSTM does not support `hidden_hidden_bias_scale`")
+
     return StackedLSTM(
         num_layers,
         LSTMLayer,
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 091c3d6dd..64e1982f6 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -1177,6 +1177,7 @@ class ModelPT(LightningModule, Model):

         Args:
             trainer: PyTorch Lightning Trainer object.
         """
+        self.trainer = trainer
         self._trainer = trainer
         self.set_world_size(self._trainer)
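For reference, the effect of the two new LSTM options can be seen by applying the same two steps from `LSTMDropout.__init__` to a plain `torch.nn.LSTM`. This is a hedged sketch that mirrors the diff rather than calling the NeMo class; the sizes and the 0.5 scale are arbitrary. Note that `ln_lstm` (the `norm='layer'` path) explicitly rejects both options in this change.

```python
# Sketch mirroring the two new LSTMDropout steps on a plain torch.nn.LSTM
# (not the NeMo class itself); sizes and the 0.5 scale are arbitrary.
import torch

hidden_size = 8
lstm = torch.nn.LSTM(input_size=4, hidden_size=hidden_size, num_layers=1)

weights_init_scale = 0.5        # shrink every weight/bias after the default init
hidden_hidden_bias_scale = 0.0  # zero the forget-gate slice of b_hh (old hard-coded behaviour)

for name, param in lstm.named_parameters():
    if "bias_hh" in name:
        # forget-gate slice of the hidden-to-hidden bias
        param.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale)
    if "weight" in name or "bias" in name:
        param.data *= float(weights_init_scale)

print(lstm.weight_ih_l0.abs().max())  # roughly half the default U(-1/sqrt(H), 1/sqrt(H)) bound
```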