NeMo/nemo/collections/asr/modules/conv_asr.py
Samuel Kriman b7a175b7b9
Self-supervised pre-training for speech models (#3139)
Co-authored-by: Yang Zhang <yzhang123@users.noreply.github.com>
2021-11-10 15:33:11 -08:00

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import MISSING, ListConfig, OmegaConf

from nemo.collections.asr.parts.submodules.jasper import (
    JasperBlock,
    MaskedConv1d,
    ParallelBlock,
    SqueezeExcite,
    init_weights,
    jasper_activations,
)
from nemo.collections.asr.parts.submodules.tdnn_attention import (
    AttentivePoolLayer,
    StatsPoolLayer,
    TDNNModule,
    TDNNSEModule,
)
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import (
    AcousticEncodedRepresentation,
    LengthsType,
    LogitsType,
    LogprobsType,
    NeuralType,
    SpectrogramType,
)
from nemo.utils import logging

__all__ = ['ConvASRDecoder', 'ConvASREncoder', 'ConvASRDecoderClassification']


class ConvASREncoder(NeuralModule, Exportable):
    """
    Convolutional encoder for ASR models. With this class you can implement JasperNet and QuartzNet models.

    Based on these papers:
        https://arxiv.org/pdf/1904.03288.pdf
        https://arxiv.org/pdf/1910.10261.pdf
    """

    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for m in self.modules():
            if isinstance(m, MaskedConv1d):
                if not self._rnnt_export:
                    m.use_mask = False
                    m_count += 1
        Exportable._prepare_for_export(self, **kwargs)
        logging.warning(f"Turned off {m_count} masked convolutions")

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(1, self._feat_in, 8192).to(next(self.parameters()).device)
        lens = torch.randint(0, input_example.shape[-1], size=(input_example.shape[0],))
        if self._rnnt_export:
            return tuple([input_example, lens])
        else:
            return tuple([input_example])

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        if self._rnnt_export:
            return set([])
        else:
            return set(["length"])

    @property
    def disabled_deployment_output_names(self):
        """Implement this method to return a set of output names disabled for export"""
        if self._rnnt_export:
            return set([])
        else:
            return set(["encoded_lengths"])

    def save_to(self, save_path: str):
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        jasper,
        activation: str,
        feat_in: int,
        normalization_mode: str = "batch",
        residual_mode: str = "add",
        norm_groups: int = -1,
        conv_mask: bool = True,
        frame_splicing: int = 1,
        init_mode: Optional[str] = 'xavier_uniform',
        quantize: bool = False,
    ):
        super().__init__()
        if isinstance(jasper, ListConfig):
            jasper = OmegaConf.to_container(jasper)

        activation = jasper_activations[activation]()

        # If the activation can be executed in place, do so.
        if hasattr(activation, 'inplace'):
            activation.inplace = True

        feat_in = feat_in * frame_splicing
        self._feat_in = feat_in

        residual_panes = []
        encoder_layers = []
        self.dense_residual = False
        for lcfg in jasper:
            dense_res = []
            if lcfg.get('residual_dense', False):
                residual_panes.append(feat_in)
                dense_res = residual_panes
                self.dense_residual = True
            groups = lcfg.get('groups', 1)
            separable = lcfg.get('separable', False)
            heads = lcfg.get('heads', -1)
            residual_mode = lcfg.get('residual_mode', residual_mode)
            se = lcfg.get('se', False)
            se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
            se_context_window = lcfg.get('se_context_size', -1)
            se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
            kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
            stride_last = lcfg.get('stride_last', False)
            future_context = lcfg.get('future_context', -1)
            encoder_layers.append(
                JasperBlock(
                    feat_in,
                    lcfg['filters'],
                    repeat=lcfg['repeat'],
                    kernel_size=lcfg['kernel'],
                    stride=lcfg['stride'],
                    dilation=lcfg['dilation'],
                    dropout=lcfg['dropout'],
                    residual=lcfg['residual'],
                    groups=groups,
                    separable=separable,
                    heads=heads,
                    residual_mode=residual_mode,
                    normalization=normalization_mode,
                    norm_groups=norm_groups,
                    activation=activation,
                    residual_panes=dense_res,
                    conv_mask=conv_mask,
                    se=se,
                    se_reduction_ratio=se_reduction_ratio,
                    se_context_window=se_context_window,
                    se_interpolation_mode=se_interpolation_mode,
                    kernel_size_factor=kernel_size_factor,
                    stride_last=stride_last,
                    future_context=future_context,
                    quantize=quantize,
                )
            )
            feat_in = lcfg['filters']

        self._feat_out = feat_in
        self.encoder = torch.nn.Sequential(*encoder_layers)
        self.apply(lambda x: init_weights(x, mode=init_mode))

        # Flag needed for RNNT export support
        self._rnnt_export = False

    @typecheck()
    def forward(self, audio_signal, length=None):
        # JasperBlock consumes and produces ([tensor, ...], length) tuples so
        # that dense-residual blocks can see the outputs of earlier blocks;
        # only the last output pane is returned here.
        s_input, length = self.encoder(([audio_signal], length))
        if length is None:
            return s_input[-1]

        return s_input[-1], length


class ParallelConvASREncoder(NeuralModule, Exportable):
    """
    Convolutional encoder for ASR models with parallel blocks. CarneliNet can be implemented with this class.
    """

    def _prepare_for_export(self):
        m_count = 0
        for m in self.modules():
            if isinstance(m, MaskedConv1d):
                m.use_mask = False
                m_count += 1
        logging.warning(f"Turned off {m_count} masked convolutions")

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device)
        return tuple([input_example])

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        return set(["length"])

    @property
    def disabled_deployment_output_names(self):
        """Implement this method to return a set of output names disabled for export"""
        return set(["encoded_lengths"])

    def save_to(self, save_path: str):
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        jasper,
        activation: str,
        feat_in: int,
        normalization_mode: str = "batch",
        residual_mode: str = "add",
        norm_groups: int = -1,
        conv_mask: bool = True,
        frame_splicing: int = 1,
        init_mode: Optional[str] = 'xavier_uniform',
        aggregation_mode: Optional[str] = None,
        quantize: bool = False,
    ):
        super().__init__()
        if isinstance(jasper, ListConfig):
            jasper = OmegaConf.to_container(jasper)

        activation = jasper_activations[activation]()
        feat_in = feat_in * frame_splicing
        self._feat_in = feat_in

        residual_panes = []
        encoder_layers = []
        self.dense_residual = False
        for lcfg in jasper:
            dense_res = []
            if lcfg.get('residual_dense', False):
                residual_panes.append(feat_in)
                dense_res = residual_panes
                self.dense_residual = True
            groups = lcfg.get('groups', 1)
            separable = lcfg.get('separable', False)
            heads = lcfg.get('heads', -1)
            residual_mode = lcfg.get('residual_mode', residual_mode)
            se = lcfg.get('se', False)
            se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
            se_context_window = lcfg.get('se_context_size', -1)
            se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
            kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
            stride_last = lcfg.get('stride_last', False)
            aggregation_mode = lcfg.get('aggregation_mode', 'sum')
            block_dropout = lcfg.get('block_dropout', 0.0)
            parallel_residual_mode = lcfg.get('parallel_residual_mode', 'sum')

            parallel_blocks = []
            for kernel_size in lcfg['kernel']:
                parallel_blocks.append(
                    JasperBlock(
                        feat_in,
                        lcfg['filters'],
                        repeat=lcfg['repeat'],
                        kernel_size=[kernel_size],
                        stride=lcfg['stride'],
                        dilation=lcfg['dilation'],
                        dropout=lcfg['dropout'],
                        residual=lcfg['residual'],
                        groups=groups,
                        separable=separable,
                        heads=heads,
                        residual_mode=residual_mode,
                        normalization=normalization_mode,
                        norm_groups=norm_groups,
                        activation=activation,
                        residual_panes=dense_res,
                        conv_mask=conv_mask,
                        se=se,
                        se_reduction_ratio=se_reduction_ratio,
                        se_context_window=se_context_window,
                        se_interpolation_mode=se_interpolation_mode,
                        kernel_size_factor=kernel_size_factor,
                        stride_last=stride_last,
                        quantize=quantize,
                    )
                )
            if len(parallel_blocks) == 1:
                encoder_layers.append(parallel_blocks[0])
            else:
                encoder_layers.append(
                    ParallelBlock(
                        parallel_blocks,
                        aggregation_mode=aggregation_mode,
                        block_dropout_prob=block_dropout,
                        residual_mode=parallel_residual_mode,
                        in_filters=feat_in,
                        out_filters=lcfg['filters'],
                    )
                )
            feat_in = lcfg['filters']

        self._feat_out = feat_in
        self.encoder = torch.nn.Sequential(*encoder_layers)
        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, audio_signal, length=None):
        s_input, length = self.encoder(([audio_signal], length))
        if length is None:
            return s_input[-1]

        return s_input[-1], length


class ConvASRDecoder(NeuralModule, Exportable):
    """Simple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet

    Based on these papers:
        https://arxiv.org/pdf/1904.03288.pdf
        https://arxiv.org/pdf/1910.10261.pdf
        https://arxiv.org/pdf/2005.04290.pdf
    """

    def save_to(self, save_path: str):
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())})

    def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary=None):
        super().__init__()

        if vocabulary is not None and num_classes != len(vocabulary):
            raise ValueError(
                f"If vocabulary is specified, its length should be equal to num_classes. "
                f"Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}"
            )
        # Store the vocabulary even when it is None, so the `vocabulary`
        # property is always defined.
        self.__vocabulary = vocabulary
        self._feat_in = feat_in
        # Add 1 for blank char
        self._num_classes = num_classes + 1

        self.decoder_layers = torch.nn.Sequential(
            torch.nn.Conv1d(self._feat_in, self._num_classes, kernel_size=1, bias=True)
        )
        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, encoder_output):
        return torch.nn.functional.log_softmax(self.decoder_layers(encoder_output).transpose(1, 2), dim=-1)

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        bs = 8
        seq = 64
        input_example = torch.randn(bs, self._feat_in, seq).to(next(self.parameters()).device)
        return tuple([input_example])

    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for m in self.modules():
            if type(m).__name__ == "MaskedConv1d":
                m.use_mask = False
                m_count += 1
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
        Exportable._prepare_for_export(self, **kwargs)

    @property
    def vocabulary(self):
        return self.__vocabulary

    @property
    def num_classes_with_blank(self):
        return self._num_classes


class ConvASRDecoderReconstruction(NeuralModule, Exportable):
    """ASR decoder for reconstructing masked regions of the input spectrogram,
    used in self-supervised pre-training.
    """

    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        return OrderedDict({"spec_recon": NeuralType(('B', 'T', 'D'), SpectrogramType())})

    def __init__(
        self,
        feat_in,
        feat_out,
        feat_hidden,
        stride_layers,
        kernel_size=11,
        init_mode="xavier_uniform",
        activation="relu",
    ):
        super().__init__()

        if stride_layers > 0 and (kernel_size < 3 or kernel_size % 2 == 0):
            raise ValueError(
                "Kernel size in this decoder needs to be >= 3 and odd when using at least 1 stride layer."
            )

        activation = jasper_activations[activation]()

        self.feat_in = feat_in
        self.feat_out = feat_out
        self.feat_hidden = feat_hidden

        self.decoder_layers = [nn.Conv1d(self.feat_in, self.feat_hidden, kernel_size=1, bias=True)]
        for i in range(stride_layers):
            self.decoder_layers.append(activation)
            self.decoder_layers.append(
                nn.ConvTranspose1d(
                    self.feat_hidden,
                    self.feat_hidden,
                    kernel_size,
                    stride=2,
                    padding=(kernel_size - 3) // 2 + 1,
                    output_padding=1,
                    bias=True,
                )
            )
            self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_hidden, kernel_size=1, bias=True))
            self.decoder_layers.append(nn.BatchNorm1d(self.feat_hidden, eps=1e-3, momentum=0.1))
        self.decoder_layers.append(activation)
        self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_out, kernel_size=1, bias=True))

        self.decoder_layers = nn.Sequential(*self.decoder_layers)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, encoder_output):
        return self.decoder_layers(encoder_output).transpose(-2, -1)

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        bs = 8
        seq = 64
        input_example = torch.randn(bs, self.feat_in, seq).to(next(self.parameters()).device)
        return tuple([input_example])

    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for m in self.modules():
            if type(m).__name__ == "MaskedConv1d":
                m.use_mask = False
                m_count += 1
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
        Exportable._prepare_for_export(self, **kwargs)


class ConvASRDecoderClassification(NeuralModule, Exportable):
    """Simple ASR Decoder for use with classification models such as JasperNet and QuartzNet

    Based on these papers:
        https://arxiv.org/pdf/2005.04290.pdf
    """

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(16, self._feat_in, 128).to(next(self.parameters()).device)
        return tuple([input_example])

    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        return OrderedDict({"logits": NeuralType(('B', 'D'), LogitsType())})

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        init_mode: Optional[str] = "xavier_uniform",
        return_logits: bool = True,
        pooling_type='avg',
    ):
        super().__init__()

        self._feat_in = feat_in
        self._return_logits = return_logits
        self._num_classes = num_classes

        if pooling_type == 'avg':
            self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        elif pooling_type == 'max':
            self.pooling = torch.nn.AdaptiveMaxPool1d(1)
        else:
            raise ValueError('Pooling type chosen is not valid. Must be either `avg` or `max`')

        self.decoder_layers = torch.nn.Sequential(torch.nn.Linear(self._feat_in, self._num_classes, bias=True))
        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, encoder_output):
        batch, in_channels, timesteps = encoder_output.size()

        encoder_output = self.pooling(encoder_output).view(batch, in_channels)  # [B, C]
        logits = self.decoder_layers(encoder_output)  # [B, num_classes]

        if self._return_logits:
            return logits

        return torch.nn.functional.softmax(logits, dim=-1)

    @property
    def num_classes(self):
        return self._num_classes


class ECAPAEncoder(NeuralModule, Exportable):
    """
    Modified ECAPA encoder layer without the Res2Net module, for faster training and
    inference; this variant achieves better numbers on speaker diarization tasks.

    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)

    Args:
        feat_in: input feature shape (mel spec feature shape)
        filters: list of filter shapes for SE_TDNN modules
        kernel_sizes: list of kernel shapes for SE_TDNN modules
        dilations: list of dilations for group conv se layer
        scale: scale value to group wider conv channels (default: 8)

    Returns:
        outputs: encoded output
        output_length: masked output lengths
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        feat_in: int,
        filters: list,
        kernel_sizes: list,
        dilations: list,
        scale: int = 8,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()

        self.layers = nn.ModuleList()
        self.layers.append(TDNNModule(feat_in, filters[0], kernel_size=kernel_sizes[0], dilation=dilations[0]))

        for i in range(len(filters) - 2):
            self.layers.append(
                TDNNSEModule(
                    filters[i],
                    filters[i + 1],
                    group_scale=scale,
                    se_channels=128,
                    kernel_size=kernel_sizes[i + 1],
                    dilation=dilations[i + 1],
                )
            )
        self.feature_agg = TDNNModule(filters[-1], filters[-1], kernel_sizes[-1], dilations[-1])
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def forward(self, audio_signal, length=None):
        x = audio_signal
        outputs = []

        for layer in self.layers:
            x = layer(x, length=length)
            outputs.append(x)

        # Aggregate the SE-TDNN block outputs, skipping the input layer.
        x = torch.cat(outputs[1:], dim=1)
        x = self.feature_agg(x)
        return x, length


class SpeakerDecoder(NeuralModule, Exportable):
    """
    Speaker Decoder creates the final neural layers that map from the outputs
    of the Jasper encoder to the embedding layer, followed by a speaker-based softmax loss.

    Args:
        feat_in (int): Number of channels being input to this module
        num_classes (int): Number of unique speakers in the dataset
        emb_sizes (list): shapes of intermediate embedding layers (speaker embeddings
            are taken from the first of these layers).
            Defaults to 256.
        pool_mode (str): Pooling strategy type. Options are 'xvector', 'tap', 'attention'.
            xvector: mean and variance pooling
            tap: temporal average pooling (just the mean)
            attention: attention-based pooling
            Defaults to 'xvector'.
        init_mode (str): Describes how neural network parameters are
            initialized. Options are ['xavier_uniform', 'xavier_normal',
            'kaiming_uniform', 'kaiming_normal'].
            Defaults to "xavier_uniform".
    """

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(16, self.input_feat_in, 256).to(next(self.parameters()).device)
        return tuple([input_example])

    @property
    def input_types(self):
        return OrderedDict(
            {
                "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "length": NeuralType(('B',), LengthsType(), optional=True),
            }
        )

    @property
    def output_types(self):
        return OrderedDict(
            {
                "logits": NeuralType(('B', 'D'), LogitsType()),
                "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()),
            }
        )

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = "xavier_uniform",
    ):
        super().__init__()
        self.angular = angular
        self.emb_id = 2
        bias = False if self.angular else True
        emb_sizes = [emb_sizes] if type(emb_sizes) is int else emb_sizes

        self.input_feat_in = feat_in  # kept for input_example()
        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()
        if self.pool_mode == 'xvector' or self.pool_mode == 'tap':
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'

        shapes = [self._pooling.feat_in]
        for size in emb_sizes:
            shapes.append(int(size))

        emb_layers = []
        for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
            layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
            emb_layers.append(layer)

        self.emb_layers = nn.ModuleList(emb_layers)

        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    def affine_layer(
        self, inp_shape, out_shape, learn_mean=True, affine_type='conv',
    ):
        if affine_type == 'conv':
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )
        else:
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )

        return layer

    @typecheck()
    def forward(self, encoder_output, length=None):
        pool = self._pooling(encoder_output, length)
        embs = []

        for layer in self.emb_layers:
            # Advance the features through the full affine block while also
            # tapping the embedding from its first `self.emb_id` sublayers.
            pool, emb = layer(pool), layer[: self.emb_id](pool)
            embs.append(emb)

        pool = pool.squeeze(-1)
        if self.angular:
            # Normalize both the classifier weights and the pooled embedding so
            # that the logits are cosine similarities, as angular-margin losses
            # expect.
            self.final.weight.data = F.normalize(self.final.weight.data, p=2, dim=1)
            pool = F.normalize(pool, p=2, dim=1)

        out = self.final(pool)

        return out, embs[-1].squeeze(-1)


@dataclass
class JasperEncoderConfig:
    filters: int = MISSING
    repeat: int = MISSING
    kernel: List[int] = MISSING
    stride: List[int] = MISSING
    dilation: List[int] = MISSING
    dropout: float = MISSING
    residual: bool = MISSING

    # Optional arguments
    groups: int = 1
    separable: bool = False
    heads: int = -1
    residual_mode: str = "add"
    residual_dense: bool = False
    se: bool = False
    se_reduction_ratio: int = 8
    se_context_size: int = -1
    se_interpolation_mode: str = 'nearest'
    kernel_size_factor: float = 1.0
    stride_last: bool = False


@dataclass
class ConvASREncoderConfig:
    _target_: str = 'nemo.collections.asr.modules.ConvASREncoder'
    jasper: Optional[List[JasperEncoderConfig]] = field(default_factory=list)
    activation: str = MISSING
    feat_in: int = MISSING
    normalization_mode: str = "batch"
    residual_mode: str = "add"
    norm_groups: int = -1
    conv_mask: bool = True
    frame_splicing: int = 1
    init_mode: Optional[str] = "xavier_uniform"


@dataclass
class ConvASRDecoderConfig:
    _target_: str = 'nemo.collections.asr.modules.ConvASRDecoder'
    feat_in: int = MISSING
    num_classes: int = MISSING
    init_mode: Optional[str] = "xavier_uniform"
    vocabulary: Optional[List[str]] = field(default_factory=list)


@dataclass
class ConvASRDecoderClassificationConfig:
    _target_: str = 'nemo.collections.asr.modules.ConvASRDecoderClassification'
    feat_in: int = MISSING
    num_classes: int = MISSING
    init_mode: Optional[str] = "xavier_uniform"
    return_logits: bool = True
    pooling_type: str = 'avg'
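

# Usage sketch (not part of the original file): these dataclasses serve as
# Hydra/OmegaConf structured configs for the modules above; `_target_` tells
# hydra.utils.instantiate which class to build.
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.structured(
#       ConvASRDecoderConfig(feat_in=256, num_classes=3, vocabulary=['a', 'b', 'c'])
#   )
#   print(OmegaConf.to_yaml(cfg))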