# DeepLearningExamples/PyTorch/SpeechRecognition/Jasper/model.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex import amp
import torch
import torch.nn as nn
from parts.features import FeatureFactory
from helpers import Optimization
import random
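

# Activation functions selectable via the encoder config's 'activation' key.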
jasper_activations = {
"hardtanh": nn.Hardtanh,
"relu": nn.ReLU,
"selu": nn.SELU,
}


def init_weights(m, mode='xavier_uniform'):
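    """Initialize convolution weights and reset BatchNorm1d statistics in place."""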
if type(m) == nn.Conv1d or type(m) == MaskedConv1d:
if mode == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight, gain=1.0)
elif mode == 'xavier_normal':
nn.init.xavier_normal_(m.weight, gain=1.0)
elif mode == 'kaiming_uniform':
nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
elif mode == 'kaiming_normal':
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
else:
raise ValueError("Unknown Initialization mode: {0}".format(mode))
elif type(m) == nn.BatchNorm1d:
if m.track_running_stats:
m.running_mean.zero_()
m.running_var.fill_(1)
m.num_batches_tracked.zero_()
if m.affine:
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)


def get_same_padding(kernel_size, stride, dilation):
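    """Return the symmetric padding that keeps output length equal to input length.

    E.g. kernel_size=11, stride=1, dilation=1 gives padding 5. Stride and
    dilation may not both exceed 1, since 'same' padding is ill-defined then.
    """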
if stride > 1 and dilation > 1:
raise ValueError("Only stride OR dilation may be greater than 1")
return (kernel_size // 2) * dilation


class AudioPreprocessing(nn.Module):
    """GPU-accelerated audio preprocessing (feature extraction)."""
def __init__(self, **kwargs):
nn.Module.__init__(self) # For PyTorch API
self.optim_level = kwargs.get('optimization_level', Optimization.nothing)
self.featurizer = FeatureFactory.from_config(kwargs)
def forward(self, x):
input_signal, length = x
length.requires_grad_(False)
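        # For AMP optimization levels that autocast (i.e. O1/O2), compute
        # features in fp32 for numerical accuracy.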
if self.optim_level not in [Optimization.nothing, Optimization.mxprO0, Optimization.mxprO3]:
with amp.disable_casts():
processed_signal = self.featurizer(x)
processed_length = self.featurizer.get_seq_len(length)
else:
processed_signal = self.featurizer(x)
processed_length = self.featurizer.get_seq_len(length)
return processed_signal, processed_length


class SpectrogramAugmentation(nn.Module):
    """Applies Cutout and SpecAugment to input spectrograms."""
def __init__(self, **kwargs):
nn.Module.__init__(self)
self.spec_cutout_regions = SpecCutoutRegions(kwargs)
self.spec_augment = SpecAugment(kwargs)
@torch.no_grad()
def forward(self, input_spec):
augmented_spec = self.spec_cutout_regions(input_spec)
augmented_spec = self.spec_augment(augmented_spec)
return augmented_spec


class SpecAugment(nn.Module):
    """SpecAugment: masks random blocks of frequency and time steps.

    See https://arxiv.org/abs/1904.08779
    """
def __init__(self, cfg, rng=None):
super(SpecAugment, self).__init__()
self._rng = random.Random() if rng is None else rng
self.cutout_x_regions = cfg.get('cutout_x_regions', 0)
self.cutout_y_regions = cfg.get('cutout_y_regions', 0)
self.cutout_x_width = cfg.get('cutout_x_width', 10)
self.cutout_y_width = cfg.get('cutout_y_width', 10)
@torch.no_grad()
def forward(self, x):
sh = x.shape
        mask = torch.zeros(x.shape, dtype=torch.bool)
for idx in range(sh[0]):
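            # mask random blocks along the frequency axis (dim 1)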
for _ in range(self.cutout_x_regions):
cutout_x_left = int(self._rng.uniform(0, sh[1] - self.cutout_x_width))
mask[idx, cutout_x_left:cutout_x_left + self.cutout_x_width, :] = 1
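            # mask random blocks along the time axis (dim 2)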
for _ in range(self.cutout_y_regions):
cutout_y_left = int(self._rng.uniform(0, sh[2] - self.cutout_y_width))
mask[idx, :, cutout_y_left:cutout_y_left + self.cutout_y_width] = 1
x = x.masked_fill(mask.to(device=x.device), 0)
return x


class SpecCutoutRegions(nn.Module):
    """Cutout: zeroes random rectangles in the time-frequency plane.

    See https://arxiv.org/pdf/1708.04552.pdf
    """
def __init__(self, cfg, rng=None):
super(SpecCutoutRegions, self).__init__()
self._rng = random.Random() if rng is None else rng
self.cutout_rect_regions = cfg.get('cutout_rect_regions', 0)
self.cutout_rect_time = cfg.get('cutout_rect_time', 5)
self.cutout_rect_freq = cfg.get('cutout_rect_freq', 20)
@torch.no_grad()
def forward(self, x):
sh = x.shape
        mask = torch.zeros(x.shape, dtype=torch.bool)
for idx in range(sh[0]):
            for _ in range(self.cutout_rect_regions):
cutout_rect_x = int(self._rng.uniform(
0, sh[1] - self.cutout_rect_freq))
cutout_rect_y = int(self._rng.uniform(
0, sh[2] - self.cutout_rect_time))
mask[idx, cutout_rect_x:cutout_rect_x + self.cutout_rect_freq,
cutout_rect_y:cutout_rect_y + self.cutout_rect_time] = 1
x = x.masked_fill(mask.to(device=x.device), 0)
return x


class JasperEncoder(nn.Module):
    """Jasper encoder: a stack of JasperBlocks built from the 'jasper' config section."""
def __init__(self, **kwargs):
        cfg = dict(kwargs)
nn.Module.__init__(self)
self._cfg = cfg
activation = jasper_activations[cfg['encoder']['activation']]()
use_conv_mask = cfg['encoder'].get('convmask', False)
feat_in = cfg['input']['features'] * cfg['input'].get('frame_splicing', 1)
init_mode = cfg.get('init_mode', 'xavier_uniform')
residual_panes = []
encoder_layers = []
self.dense_residual = False
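        # Blocks flagged 'residual_dense' receive skip connections from all
        # preceding such blocks (the Jasper dense-residual topology).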
for lcfg in cfg['jasper']:
dense_res = []
if lcfg.get('residual_dense', False):
residual_panes.append(feat_in)
dense_res = residual_panes
self.dense_residual = True
encoder_layers.append(
JasperBlock(feat_in, lcfg['filters'], repeat=lcfg['repeat'],
kernel_size=lcfg['kernel'], stride=lcfg['stride'],
dilation=lcfg['dilation'], dropout=lcfg['dropout'],
residual=lcfg['residual'], activation=activation,
residual_panes=dense_res, conv_mask=use_conv_mask))
feat_in = lcfg['filters']
self.encoder = nn.Sequential(*encoder_layers)
self.apply(lambda x: init_weights(x, mode=init_mode))
def num_weights(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def forward(self, x):
audio_signal, length = x
s_input, length = self.encoder(([audio_signal], length))
return s_input, length


class JasperDecoderForCTC(nn.Module):
    """Jasper decoder: a 1x1 convolution projecting encoder features to per-frame class log-probabilities."""
def __init__(self, **kwargs):
nn.Module.__init__(self)
self._feat_in = kwargs.get("feat_in")
self._num_classes = kwargs.get("num_classes")
init_mode = kwargs.get('init_mode', 'xavier_uniform')
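        # Conv1d output is (batch, classes, time), so LogSoftmax over dim=1
        # normalizes across classes at each frame.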
self.decoder_layers = nn.Sequential(
nn.Conv1d(self._feat_in, self._num_classes, kernel_size=1, bias=True),
nn.LogSoftmax(dim=1))
self.apply(lambda x: init_weights(x, mode=init_mode))
def num_weights(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def forward(self, encoder_output):
out = self.decoder_layers(encoder_output[-1])
return out.transpose(1, 2)


class Jasper(nn.Module):
    """Contains data preprocessing, spectrogram augmentation, Jasper encoder, and decoder."""
def __init__(self, **kwargs):
nn.Module.__init__(self)
self.audio_preprocessor = AudioPreprocessing(**kwargs.get("feature_config"))
self.data_spectr_augmentation = SpectrogramAugmentation(**kwargs.get("feature_config"))
self.jasper_encoder = JasperEncoder(**kwargs.get("jasper_model_definition"))
self.jasper_decoder = JasperDecoderForCTC(feat_in=kwargs.get("feat_in"),
num_classes=kwargs.get("num_classes"))
def num_weights(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def forward(self, x):
input_signal, length = x
t_processed_signal, p_length_t = self.audio_preprocessor(x)
if self.training:
t_processed_signal = self.data_spectr_augmentation(input_spec=t_processed_signal)
t_encoded_t, t_encoded_len_t = self.jasper_encoder((t_processed_signal, p_length_t))
return self.jasper_decoder(encoder_output=t_encoded_t), t_encoded_len_t


class JasperEncoderDecoder(nn.Module):
    """Contains the Jasper encoder and decoder only (no preprocessing)."""
def __init__(self, **kwargs):
nn.Module.__init__(self)
self.jasper_encoder = JasperEncoder(**kwargs.get("jasper_model_definition"))
self.jasper_decoder = JasperDecoderForCTC(feat_in=kwargs.get("feat_in"),
num_classes=kwargs.get("num_classes"))
def num_weights(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def forward(self, x):
t_processed_signal, p_length_t = x
t_encoded_t, t_encoded_len_t = self.jasper_encoder((t_processed_signal, p_length_t))
return self.jasper_decoder(encoder_output=t_encoded_t), t_encoded_len_t


class MaskedConv1d(nn.Conv1d):
    """1D convolution that zeroes inputs beyond each sequence's valid length."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=False, use_mask=True):
super(MaskedConv1d, self).__init__(in_channels, out_channels, kernel_size,
stride=stride,
padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.use_mask = use_mask
def get_seq_len(self, lens):
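        # standard Conv1d output-length formula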
        return ((lens + 2 * self.padding[0] - self.dilation[0] * (
            self.kernel_size[0] - 1) - 1) // self.stride[0] + 1)
def forward(self, inp):
x, lens = inp
if self.use_mask:
max_len = x.size(2)
            mask = torch.arange(max_len, dtype=lens.dtype, device=lens.device)
            mask = mask.expand(len(lens), max_len) >= lens.unsqueeze(1)
            x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
            del mask
lens = self.get_seq_len(lens)
out = super(MaskedConv1d, self).forward(x)
return out, lens


class JasperBlock(nn.Module):
    """Jasper block: repeated conv/batch-norm/activation/dropout sub-blocks with residual connections.

    See https://arxiv.org/pdf/1904.03288.pdf
    """
    def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
                 dilation=1, padding='same', dropout=0.2, activation=None,
                 residual=True, residual_panes=None, conv_mask=False):
        super(JasperBlock, self).__init__()
        # avoid sharing a mutable default argument across instances
        residual_panes = [] if residual_panes is None else residual_panes
if padding != "same":
raise ValueError("currently only 'same' padding is supported")
padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])
self.conv_mask = conv_mask
self.conv = nn.ModuleList()
inplanes_loop = inplanes
for _ in range(repeat - 1):
self.conv.extend(
self._get_conv_bn_layer(inplanes_loop, planes, kernel_size=kernel_size,
stride=stride, dilation=dilation,
padding=padding_val))
self.conv.extend(
self._get_act_dropout_layer(drop_prob=dropout, activation=activation))
inplanes_loop = planes
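        # last sub-block: conv + batch norm only; activation and dropout are
        # applied after the residual sum, via self.out in forward()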
self.conv.extend(
self._get_conv_bn_layer(inplanes_loop, planes, kernel_size=kernel_size,
stride=stride, dilation=dilation,
padding=padding_val))
self.res = nn.ModuleList() if residual else None
res_panes = residual_panes.copy()
self.dense_residual = residual
if residual:
if len(residual_panes) == 0:
res_panes = [inplanes]
self.dense_residual = False
for ip in res_panes:
self.res.append(nn.ModuleList(
modules=self._get_conv_bn_layer(ip, planes, kernel_size=1)))
self.out = nn.Sequential(
*self._get_act_dropout_layer(drop_prob=dropout, activation=activation))
def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,
stride=1, dilation=1, padding=0, bias=False):
layers = [
MaskedConv1d(in_channels, out_channels, kernel_size, stride=stride,
dilation=dilation, padding=padding, bias=bias,
use_mask=self.conv_mask),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)
]
return layers
def _get_act_dropout_layer(self, drop_prob=0.2, activation=None):
if activation is None:
activation = nn.Hardtanh(min_val=0.0, max_val=20.0)
layers = [
activation,
nn.Dropout(p=drop_prob)
]
return layers
def num_weights(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def forward(self, input_):
xs, lens_orig = input_
# compute forward convolutions
out = xs[-1]
lens = lens_orig
for i, l in enumerate(self.conv):
if isinstance(l, MaskedConv1d):
out, lens = l((out, lens))
else:
out = l(out)
# compute the residuals
if self.res is not None:
for i, layer in enumerate(self.res):
res_out = xs[i]
for j, res_layer in enumerate(layer):
if j == 0:
res_out, _ = res_layer((res_out, lens_orig))
else:
res_out = res_layer(res_out)
out += res_out
# compute the output
out = self.out(out)
if self.res is not None and self.dense_residual:
return xs + [out], lens
return [out], lens


class GreedyCTCDecoder(nn.Module):
    """Greedy CTC decoder: picks the highest-probability symbol at each frame."""
def __init__(self, **kwargs):
nn.Module.__init__(self) # For PyTorch API
def forward(self, log_probs):
with torch.no_grad():
argmx = log_probs.argmax(dim=-1, keepdim=False).int()
return argmx


class CTCLossNM:
    """CTC loss with the blank symbol mapped to the last class index."""
def __init__(self, **kwargs):
self._blank = kwargs['num_classes'] - 1
self._criterion = nn.CTCLoss(blank=self._blank, reduction='none')
def __call__(self, log_probs, targets, input_length, target_length):
input_length = input_length.long()
target_length = target_length.long()
targets = targets.long()
loss = self._criterion(log_probs.transpose(1, 0), targets, input_length,
target_length)
# note that this is different from reduction = 'mean'
# because we are not dividing by target lengths
return torch.mean(loss)
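

# Minimal usage sketch (hypothetical shapes and config values; the real model
# definitions live in this repo's configs/ directory):
#
#   model = JasperEncoderDecoder(jasper_model_definition=cfg,
#                                feat_in=1024, num_classes=29)
#   log_probs, out_lens = model((features, feature_lens))
#   loss = CTCLossNM(num_classes=29)(log_probs, targets, out_lens, target_lens)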