# DeepLearningExamples/PyTorch/SpeechRecognition/QuartzNet/quartznet/model.py
# 2021-09-14 06:03:36 -07:00
#
# 392 lines
# 14 KiB
# Python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Maps an activation name (JasperEncoder's ``activation`` argument) to the
# nn.Module class that is instantiated for every block.
activations = dict(
    hardtanh=nn.Hardtanh,
    relu=nn.ReLU,
    selu=nn.SELU,
)
def init_weights(m, mode='xavier_uniform'):
    """Initialise the parameters of a single module in place.

    Intended to be used with ``nn.Module.apply``. Conv1d layers (including
    subclasses such as MaskedConv1d) get their ``weight`` initialised
    according to ``mode``; their bias, if any, is left untouched.
    BatchNorm1d layers get their running statistics reset (mean 0, var 1,
    batch counter 0) and, when affine, weight set to 1 and bias to 0.
    All other modules are ignored.

    Args:
        m: module to initialise.
        mode: one of 'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
            'kaiming_normal'.

    Raises:
        ValueError: if ``m`` is a Conv1d and ``mode`` is not recognised.
    """
    # isinstance (not an exact type check) so Conv1d/BatchNorm1d subclasses
    # are initialised too instead of silently keeping default weights.
    if isinstance(m, nn.Conv1d):
        if mode == 'xavier_uniform':
            nn.init.xavier_uniform_(m.weight, gain=1.0)
        elif mode == 'xavier_normal':
            nn.init.xavier_normal_(m.weight, gain=1.0)
        elif mode == 'kaiming_uniform':
            nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
        elif mode == 'kaiming_normal':
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        else:
            raise ValueError("Unknown Initialization mode: {0}".format(mode))
    elif isinstance(m, nn.BatchNorm1d):
        if m.track_running_stats:
            m.running_mean.zero_()
            m.running_var.fill_(1)
            m.num_batches_tracked.zero_()
        if m.affine:
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
def compute_new_kernel_size(kernel_size, kernel_width):
    """Scale ``kernel_size`` by ``kernel_width`` and return an odd size >= 1.

    The scaled value is truncated to an int and clamped to at least 1. An
    even result is bumped up by one, so that the ``kernel_size // 2`` padding
    from get_same_padding preserves the sequence length.
    """
    scaled = int(kernel_size * kernel_width)
    if scaled < 1:
        scaled = 1
    # odd kernels only: round even sizes up
    return scaled + 1 if scaled % 2 == 0 else scaled
def get_same_padding(kernel_size, stride, dilation):
    """Return the symmetric padding ``dilation * (kernel_size // 2)``.

    For a stride-1 convolution with an odd kernel this keeps the output
    length equal to the input length.

    Raises:
        ValueError: if both ``stride`` and ``dilation`` are greater than 1.
    """
    strided_and_dilated = stride > 1 and dilation > 1
    if strided_and_dilated:
        raise ValueError("Only stride OR dilation may be greater than 1")
    half_kernel = kernel_size // 2
    return dilation * half_kernel
class GroupShuffle(nn.Module):
    """Channel shuffle that interleaves the channels of convolution groups.

    Output channel ``j * groups + g`` is input channel
    ``g * channels_per_group + j``. Assumes ``x`` is (batch, channels, time)
    with channels == groups * channels_per_group — TODO confirm at call sites.
    """

    def __init__(self, groups, channels):
        super(GroupShuffle, self).__init__()
        self.groups = groups
        self.channels_per_group = channels // groups

    def forward(self, x):
        time_steps = x.shape[-1]
        # (N, G*Cg, T) -> (N, G, Cg, T) -> (N, Cg, G, T) -> (N, Cg*G, T)
        grouped = x.view(-1, self.groups, self.channels_per_group, time_steps)
        shuffled = grouped.transpose(1, 2).contiguous()
        n_channels = self.groups * self.channels_per_group
        return shuffled.view(-1, n_channels, time_steps)
class MaskedConv1d(nn.Conv1d):
    """1D convolution with sequence masking.

    When ``use_mask`` is True, ``forward`` zeroes every time step at or
    beyond each sequence's length before convolving, and returns the
    lengths converted to the convolution's output lengths. When False it
    behaves like nn.Conv1d and returns ``x_lens`` unchanged.
    """
    __constants__ = ["masked"]

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False, use_mask=True,
                 heads=-1):
        # Jasper refactor compat: multi-head convolutions are not supported
        assert heads == -1  # Unsupported

        super(MaskedConv1d, self).__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.masked = use_mask

    def get_seq_len(self, lens):
        """Map input lengths to output lengths with the Conv1d size formula:
        ``(len + 2*pad - dilation*(kernel-1) - 1) // stride + 1``
        (truncating division)."""
        pad, ks = self.padding[0], self.kernel_size[0]
        return torch.div(lens + 2 * pad - self.dilation[0] * (ks - 1) - 1,
                         self.stride[0], rounding_mode='trunc') + 1

    def forward(self, x, x_lens=None):
        """Convolve ``x`` (N, C, L) after masking padded steps.

        Args:
            x: input of shape (N, C, L); dim 2 is treated as the time axis.
            x_lens: per-sequence lengths of shape (N,); required when
                masking is enabled.

        Returns:
            ``(output, lengths)`` where ``lengths`` are the output lengths
            when masking is enabled, otherwise ``x_lens`` unchanged.

        Raises:
            ValueError: if masking is enabled and ``x_lens`` is None.
        """
        if self.masked:
            if x_lens is None:
                raise ValueError("x_lens is required when masking is enabled")
            max_len = x.size(2)
            # Build the comparison on x_lens' device so it never mixes
            # devices, then move the resulting mask next to x.
            idxs = torch.arange(max_len, dtype=x_lens.dtype,
                                device=x_lens.device)
            mask = idxs.expand(x_lens.size(0), max_len) >= x_lens.unsqueeze(1)
            x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
            x_lens = self.get_seq_len(x_lens)

        return super(MaskedConv1d, self).forward(x), x_lens
class JasperBlock(nn.Module):
    """Jasper/QuartzNet block: repeated conv -> norm sub-blocks with an
    optional residual connection, followed by activation and dropout.

    Main path: ``repeat - 1`` x (conv -> norm [-> GroupShuffle] ->
    activation -> dropout), then one more conv -> norm [-> GroupShuffle].
    If ``residual`` is set, entry ``i`` of the input list ``xs`` is passed
    through residual projection ``i`` (1x1 conv -> norm) and added to the
    main path. The sum then goes through activation -> dropout.

    Args:
        infilters: channels of the block input.
        filters: output channels of every convolution in the block.
        repeat: number of conv/norm sub-blocks on the main path.
        kernel_size: kernel size (int or sequence). Each entry is scaled by
            ``kernel_size_factor`` and rounded up to an odd value; only the
            first entry is used to compute the padding.
        kernel_size_factor: multiplier applied to ``kernel_size``.
        stride: convolution stride (int or sequence).
        dilation: convolution dilation (int or sequence).
        padding: only ``'same'`` is supported.
        dropout: dropout probability after each activation.
        activation: activation module; ``None`` means Hardtanh(0, 20).
        residual: add residual connection(s) when True.
        groups: convolution groups; ``> 1`` also appends a GroupShuffle.
        separable: use a depthwise conv followed by a 1x1 conv.
        heads: passed to MaskedConv1d for the depthwise conv of separable
            blocks; MaskedConv1d only accepts -1.
        normalization: 'batch', 'layer', 'group' or 'instance'.
        norm_groups: groups for 'group' normalization; -1 means one group
            per channel.
        residual_panes: channel counts of the ``xs`` entries used as dense
            residual inputs. ``None`` or empty means a single residual
            projected from ``xs[0]`` and a non-dense output.
        use_conv_masks: use MaskedConv1d (length masking) instead of
            nn.Conv1d.

    Raises:
        ValueError: if ``padding`` is not 'same' or ``normalization`` is
            unknown.
    """
    __constants__ = ["conv_mask", "separable", "res", "mconv"]

    def __init__(self, infilters, filters, repeat=3, kernel_size=11,
                 kernel_size_factor=1, stride=1, dilation=1, padding='same',
                 dropout=0.2, activation=None, residual=True, groups=1,
                 separable=False, heads=-1, normalization="batch",
                 norm_groups=1, residual_panes=None, use_conv_masks=False):
        super(JasperBlock, self).__init__()

        # Params may be passed as lists, but default to ints
        def wrap(v):
            return [v] if type(v) is int else v

        kernel_size = wrap(kernel_size)
        dilation = wrap(dilation)
        padding = wrap(padding)
        stride = wrap(stride)

        if padding != "same":
            raise ValueError("currently only 'same' padding is supported")

        kernel_size_factor = float(kernel_size_factor)
        kernel_size = [compute_new_kernel_size(k, kernel_size_factor)
                       for k in kernel_size]

        padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])

        self.conv_mask = use_conv_masks
        self.separable = separable

        infilters_loop = infilters
        conv = nn.ModuleList()

        for _ in range(repeat - 1):
            conv.extend(
                self._get_conv_bn_layer(
                    infilters_loop, filters, kernel_size=kernel_size,
                    stride=stride, dilation=dilation, padding=padding_val,
                    groups=groups, heads=heads, separable=separable,
                    normalization=normalization, norm_groups=norm_groups)
            )
            conv.extend(self._get_act_dropout_layer(drop_prob=dropout,
                                                    activation=activation))
            infilters_loop = filters

        # last sub-block: no activation/dropout until after the residual sum
        conv.extend(
            self._get_conv_bn_layer(
                infilters_loop, filters, kernel_size=kernel_size, stride=stride,
                dilation=dilation, padding=padding_val, groups=groups,
                heads=heads, separable=separable, normalization=normalization,
                norm_groups=norm_groups)
        )

        self.mconv = conv

        res_panes = list(residual_panes or [])
        self.dense_residual = residual

        if residual:
            res_list = nn.ModuleList()

            if len(res_panes) == 0:
                res_panes = [infilters]
                self.dense_residual = False
            for ip in res_panes:
                res_list.append(nn.ModuleList(
                    self._get_conv_bn_layer(ip, filters, kernel_size=1,
                                            normalization=normalization,
                                            norm_groups=norm_groups,
                                            stride=[1])
                ))

            self.res = res_list
        else:
            self.res = None

        self.mout = nn.Sequential(*self._get_act_dropout_layer(
            drop_prob=dropout, activation=activation))

    def _get_conv(self, in_channels, out_channels, kernel_size=11, stride=1,
                  dilation=1, padding=0, bias=False, groups=1, heads=-1,
                  separable=False):
        """Return a MaskedConv1d when conv masks are enabled, else nn.Conv1d.

        ``heads`` is only forwarded to MaskedConv1d; ``separable`` is unused.
        """
        kw = {'in_channels': in_channels, 'out_channels': out_channels,
              'kernel_size': kernel_size, 'stride': stride, 'dilation': dilation,
              'padding': padding, 'bias': bias, 'groups': groups}

        if self.conv_mask:
            return MaskedConv1d(**kw, heads=heads, use_mask=self.conv_mask)
        else:
            return nn.Conv1d(**kw)

    def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,
                           stride=1, dilation=1, padding=0, bias=False,
                           groups=1, heads=-1, separable=False,
                           normalization="batch", norm_groups=1):
        """Return a list: conv (or depthwise + 1x1 conv), a normalization
        layer, and a GroupShuffle when ``groups > 1``."""
        if norm_groups == -1:
            norm_groups = out_channels

        if separable:
            layers = [
                # depthwise conv over time, then 1x1 conv to mix channels
                self._get_conv(in_channels, in_channels, kernel_size,
                               stride=stride, dilation=dilation, padding=padding,
                               bias=bias, groups=in_channels, heads=heads),
                self._get_conv(in_channels, out_channels, kernel_size=1,
                               stride=1, dilation=1, padding=0, bias=bias,
                               groups=groups),
            ]
        else:
            layers = [
                self._get_conv(in_channels, out_channels, kernel_size,
                               stride=stride, dilation=dilation,
                               padding=padding, bias=bias, groups=groups)
            ]

        # 'instance' and 'layer' are expressed as GroupNorm with
        # out_channels groups and a single group, respectively.
        if normalization == "group":
            layers.append(nn.GroupNorm(num_groups=norm_groups,
                                       num_channels=out_channels))
        elif normalization == "instance":
            layers.append(nn.GroupNorm(num_groups=out_channels,
                                       num_channels=out_channels))
        elif normalization == "layer":
            layers.append(nn.GroupNorm(num_groups=1, num_channels=out_channels))

        elif normalization == "batch":
            layers.append(nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1))
        else:
            raise ValueError(
                f"Normalization method ({normalization}) does not match"
                f" one of [batch, layer, group, instance]."
            )

        if groups > 1:
            layers.append(GroupShuffle(groups, out_channels))
        return layers

    def _get_act_dropout_layer(self, drop_prob=0.2, activation=None):
        """Return ``[activation, Dropout(drop_prob)]``; the default
        activation is Hardtanh clamped to [0, 20]."""
        if activation is None:
            activation = nn.Hardtanh(min_val=0.0, max_val=20.0)
        layers = [activation, nn.Dropout(p=drop_prob)]
        return layers

    def forward(self, xs, xs_lens=None):
        """Run the block.

        Args:
            xs: list of tensors; ``xs[-1]`` feeds the main path and
                ``xs[i]`` feeds residual projection ``i``.
            xs_lens: sequence lengths, used only with conv masks.

        Returns:
            ``(outputs, lens)``: ``outputs`` is ``xs + [out]`` for dense
            residual blocks and ``[out]`` otherwise; ``lens`` is the updated
            lengths with conv masks, else None.
        """
        if not self.conv_mask:
            xs_lens = 0

        # main path
        out = xs[-1]
        lens = xs_lens
        for layer in self.mconv:
            # masked convolutions consume and update the sequence lengths
            if isinstance(layer, MaskedConv1d):
                out, lens = layer(out, lens)
            else:
                out = layer(out)

        # residual path(s)
        if self.res is not None:
            for i, res_layers in enumerate(self.res):
                res_out = xs[i]
                for res_layer in res_layers:
                    if isinstance(res_layer, MaskedConv1d):
                        res_out, _ = res_layer(res_out, xs_lens)
                    else:
                        res_out = res_layer(res_out)
                out = out + res_out

        # final activation + dropout
        out = self.mout(out)
        if self.res is not None and self.dense_residual:
            out = xs + [out]
        else:
            out = [out]

        return (out, lens) if self.conv_mask else (out, None)
class JasperEncoder(nn.Module):
    """Stack of JasperBlocks built from a list of per-block config dicts.

    Args:
        in_feats: number of input feature channels (before frame splicing).
        activation: key into ``activations``; a fresh instance is created
            for every block.
        frame_splicing: multiplier applied to ``in_feats``.
        init: weight-initialisation mode passed to ``init_weights``.
        use_conv_masks: build blocks with MaskedConv1d.
        blocks: iterable of dicts of JasperBlock keyword arguments. Each
            must contain 'filters'; an optional 'residual_dense' flag turns
            on dense residuals. The dicts are not modified.
    """
    __constants__ = ["use_conv_masks"]

    def __init__(self, in_feats, activation, frame_splicing=1,
                 init='xavier_uniform', use_conv_masks=False, blocks=()):
        super(JasperEncoder, self).__init__()

        self.use_conv_masks = use_conv_masks
        self.layers = nn.ModuleList()

        in_feats *= frame_splicing
        all_residual_panes = []
        for blk_cfg in blocks:
            # Shallow-copy so the caller's config is left untouched: popping
            # 'residual_dense' in place would make a second encoder built
            # from the same config silently lose its dense residuals.
            blk = dict(blk_cfg)
            blk['activation'] = activations[activation]()

            has_residual_dense = blk.pop('residual_dense', False)
            if has_residual_dense:
                # every dense block sees the channel counts of all earlier
                # dense-block inputs
                all_residual_panes += [in_feats]
                blk['residual_panes'] = list(all_residual_panes)
            else:
                blk['residual_panes'] = []

            self.layers.append(
                JasperBlock(in_feats, use_conv_masks=use_conv_masks, **blk))

            in_feats = blk['filters']

        self.apply(lambda module: init_weights(module, mode=init))

    def forward(self, x, x_lens=None):
        """Run all blocks; returns the final list of block outputs and the
        lengths (None unless conv masks are used)."""
        out, out_lens = [x], x_lens
        for layer in self.layers:
            out, out_lens = layer(out, out_lens)

        return out, out_lens
class JasperDecoderForCTC(nn.Module):
    """1x1 convolution that maps encoder features to per-class CTC
    log-probabilities.

    Args:
        in_feats: channels of the last encoder output.
        n_classes: number of output classes.
        init: weight-initialisation mode passed to ``init_weights``.
    """

    def __init__(self, in_feats, n_classes, init='xavier_uniform'):
        super(JasperDecoderForCTC, self).__init__()
        projection = nn.Conv1d(in_feats, n_classes, kernel_size=1, bias=True)
        self.layers = nn.Sequential(projection)
        self.apply(lambda module: init_weights(module, mode=init))

    def forward(self, enc_out):
        """Project ``enc_out[-1]`` (N, C, L) to log-probabilities of shape
        (N, L, n_classes); earlier entries of ``enc_out`` are ignored."""
        logits = self.layers(enc_out[-1])
        logits = logits.transpose(1, 2)
        return F.log_softmax(logits, dim=2)
class GreedyCTCDecoder(nn.Module):
    """Per-frame argmax over CTC log-probabilities.

    Blank removal and repeat collapsing are not done here.
    """

    @torch.no_grad()
    def forward(self, log_probs):
        """Return int32 indices of the maximum along the last dimension."""
        best = torch.argmax(log_probs, dim=-1)
        return best.to(torch.int32)
class QuartzNet(nn.Module):
    """QuartzNet acoustic model: a JasperEncoder followed by a CTC decoder.

    Args:
        encoder_kw: keyword arguments for JasperEncoder.
        decoder_kw: keyword arguments for JasperDecoderForCTC.
        transpose_in: if True, swap dims 1 and 2 of the input before
            encoding. NOTE(review): this is only applied when the encoder
            does not use conv masks — confirm that is intended.
    """

    def __init__(self, encoder_kw, decoder_kw, transpose_in=False):
        super(QuartzNet, self).__init__()
        self.transpose_in = transpose_in
        self.encoder = JasperEncoder(**encoder_kw)
        self.decoder = JasperDecoderForCTC(**decoder_kw)

    def forward(self, x, x_lens=None):
        """Encode and decode ``x``.

        With conv masks, ``x_lens`` is required and ``(log_probs, lengths)``
        is returned. Without conv masks, only ``log_probs`` is returned.

        Raises:
            ValueError: if conv masks are used and ``x_lens`` is None.
        """
        if self.encoder.use_conv_masks:
            if x_lens is None:
                raise ValueError("x_lens is required when conv masks are used")
            enc, enc_lens = self.encoder(x, x_lens)
            out = self.decoder(enc)
            return out, enc_lens
        else:
            if self.transpose_in:
                x = x.transpose(1, 2)
            enc, _ = self.encoder(x)
            out = self.decoder(enc)
            return out  # XXX torchscript refuses to output None

    def infer(self, x):
        """Inference entry point; ``x`` may be a Tensor or a tuple/list.

        With conv masks, a tuple/list is unpacked as ``(features, lengths)``
        and ``(log_probs, lengths)`` is returned. A bare Tensor has no
        lengths and is rejected by ``forward``.

        Without conv masks, only ``x[0]`` is used and ``(log_probs,
        len(log_probs))`` is returned. The second value is the size of the
        first output dimension, presumably a placeholder for lengths —
        confirm with callers.
        """
        if self.encoder.use_conv_masks:
            if isinstance(x, (tuple, list)):
                feats = x[0]
                lens = x[1] if len(x) > 1 else None
                return self.forward(feats, lens)
            return self.forward(x)
        else:
            ret = self.forward(x[0])
            return ret, len(ret)
class CTCLossNM:
    """CTC loss with the last class index as blank, averaged over the batch.

    Per-utterance losses are averaged directly, without dividing each one
    by its target length. This differs from ``nn.CTCLoss(reduction='mean')``.
    """

    def __init__(self, n_classes):
        blank_idx = n_classes - 1
        self._criterion = nn.CTCLoss(blank=blank_idx, reduction='none')

    def __call__(self, log_probs, targets, input_length, target_length):
        """Return the mean per-utterance CTC loss.

        ``log_probs`` is transposed on dims 0 and 1 before being passed to
        nn.CTCLoss, which expects (T, N, C); so it is presumably
        (N, T, C) here. Targets and lengths are cast to int64.
        """
        input_length = input_length.long()
        target_length = target_length.long()
        targets = targets.long()
        time_major = log_probs.transpose(1, 0)
        per_utterance = self._criterion(time_major, targets,
                                        input_length, target_length)
        return per_utterance.mean()