Correct ASR issues + Patch for PyTorch 1.8 (#1565)

* Trim silence default to False

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update stft and torch.fft.ifft for PyTorch 1.8

Signed-off-by: smajumdar <titu1994@gmail.com>

* Style fixes

Signed-off-by: smajumdar <titu1994@gmail.com>

* Clear up old code

Signed-off-by: smajumdar <titu1994@gmail.com>
This commit is contained in:
Somshubra Majumdar 2020-12-17 14:14:59 -08:00 committed by GitHub
parent e07d05c411
commit 51ae260a50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 15 additions and 17 deletions

View file

@ -30,6 +30,4 @@ from .package_info import (
)
if "NEMO_PACKAGE_BUILDING" not in os.environ:
from nemo import core
from nemo import utils
from nemo import collections
from nemo import collections, core, utils

View file

@ -43,7 +43,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None)
blank_index=config.get('blank_index', -1),
unk_index=config.get('unk_index', -1),
normalize=config.get('normalize_transcripts', False),
trim=config.get('trim_silence', True),
trim=config.get('trim_silence', False),
load_audio=config.get('load_audio', True),
parser=config.get('parser', 'en'),
add_misc=config.get('add_misc', False),
@ -74,7 +74,7 @@ def get_bpe_dataset(
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
trim=config.get('trim_silence', True),
trim=config.get('trim_silence', False),
load_audio=config.get('load_audio', True),
add_misc=config.get('add_misc', False),
use_start_end_token=config.get('use_start_end_token', True),
@ -113,7 +113,7 @@ def get_tarred_char_dataset(
blank_index=config.get('blank_index', -1),
unk_index=config.get('unk_index', -1),
normalize=config.get('normalize_transcripts', False),
trim=config.get('trim_silence', True),
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
add_misc=config.get('add_misc', False),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
@ -157,7 +157,7 @@ def get_tarred_bpe_dataset(
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
trim=config.get('trim_silence', True),
trim=config.get('trim_silence', False),
add_misc=config.get('add_misc', False),
use_start_end_token=config.get('use_start_end_token', True),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
@ -201,7 +201,7 @@ def get_dali_char_dataset(
blank_index=config.get('blank_index', -1),
unk_index=config.get('unk_index', -1),
normalize=config.get('normalize_transcripts', False),
trim=config.get('trim_silence', True),
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
shuffle=shuffle,
device_id=device_id,

View file

@ -35,8 +35,8 @@ from nemo.utils import logging
try:
import torchaudio
import torchaudio.transforms
import torchaudio.functional
import torchaudio.transforms
TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
TORCHAUDIO_VERSION_MIN = version.parse('0.5')

View file

@ -61,8 +61,7 @@ class BeamSearchDecoderWithLM(NeuralModule):
):
try:
from ctc_decoders import Scorer
from ctc_decoders import ctc_beam_search_decoder_batch
from ctc_decoders import Scorer, ctc_beam_search_decoder_batch
except ModuleNotFoundError:
raise ModuleNotFoundError(
"BeamSearchDecoderWithLM requires the "

View file

@ -266,6 +266,7 @@ class FilterbankFeatures(nn.Module):
win_length=self.win_length,
center=False if stft_exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=False,
)
self.normalize = normalize

View file

@ -61,7 +61,7 @@ def stft(x, fft_size, hop_size, win_length, window):
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
real = x_stft[..., 0]
imag = x_stft[..., 1]

View file

@ -129,7 +129,7 @@ class InverseSTFT(nn.Module):
# when the model is in nn.DataParallel
# of PyTorch 1.2.0 (py3.7_cuda10.0.130_cudnn7.6.2_01.2)
eye_realimag = torch.stack((eye, torch.zeros(n_fft, n_fft)), dim=-1)
basis = torch.ifft(eye_realimag, signal_ndim=1) # n_fft, n_fft, 2
basis = torch.fft.ifft(eye_realimag, signal_ndim=1) # n_fft, n_fft, 2
basis[..., 1] *= -1 # because (a+b*1j)*(c+d*1j) == a*c - b*d
basis *= window
self.basis = nn.Parameter(basis, requires_grad=False) # n_fft, n_fft, 2
@ -517,7 +517,7 @@ class DegliModule(NeuralModule, Exportable):
self.mode = OperationMode.infer
def stft(self, x):
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window)
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, return_complex=False)
@typecheck()
def forward(self, x, mag, max_length=None, repeat=2):

View file

@ -592,7 +592,6 @@ class ModelPT(LightningModule, Model):
finally:
cls._set_model_restore_state(is_being_restored=False)
return checkpoint
@abstractmethod

View file

@ -45,7 +45,7 @@ Instructions
# Import the API Key
try:
from freesound_private_apikey import client_id, api_key
from freesound_private_apikey import api_key, client_id
print("API Key found !")
except ImportError:

View file

@ -127,7 +127,8 @@ def listener_process(queue, configurer, log_file, level):
logger.handle(record) # No level or filter logic applied - just do it!
except Exception:
import sys, traceback
import sys
import traceback
print('Problem:', file=sys.stderr)
traceback.print_exc(file=sys.stderr)