Correct ASR issues + Patch for Pytorch 1.8 (#1565)
* Trim silence default to False Signed-off-by: smajumdar <titu1994@gmail.com> * Update stft and torch.fft.ifft for Pytorch 1.8 Signed-off-by: smajumdar <titu1994@gmail.com> * Style fixes Signed-off-by: smajumdar <titu1994@gmail.com> * Clear up old code Signed-off-by: smajumdar <titu1994@gmail.com>
This commit is contained in:
parent
e07d05c411
commit
51ae260a50
|
@ -30,6 +30,4 @@ from .package_info import (
|
|||
)
|
||||
|
||||
if "NEMO_PACKAGE_BUILDING" not in os.environ:
|
||||
from nemo import core
|
||||
from nemo import utils
|
||||
from nemo import collections
|
||||
from nemo import collections, core, utils
|
||||
|
|
|
@ -43,7 +43,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None)
|
|||
blank_index=config.get('blank_index', -1),
|
||||
unk_index=config.get('unk_index', -1),
|
||||
normalize=config.get('normalize_transcripts', False),
|
||||
trim=config.get('trim_silence', True),
|
||||
trim=config.get('trim_silence', False),
|
||||
load_audio=config.get('load_audio', True),
|
||||
parser=config.get('parser', 'en'),
|
||||
add_misc=config.get('add_misc', False),
|
||||
|
@ -74,7 +74,7 @@ def get_bpe_dataset(
|
|||
max_duration=config.get('max_duration', None),
|
||||
min_duration=config.get('min_duration', None),
|
||||
max_utts=config.get('max_utts', 0),
|
||||
trim=config.get('trim_silence', True),
|
||||
trim=config.get('trim_silence', False),
|
||||
load_audio=config.get('load_audio', True),
|
||||
add_misc=config.get('add_misc', False),
|
||||
use_start_end_token=config.get('use_start_end_token', True),
|
||||
|
@ -113,7 +113,7 @@ def get_tarred_char_dataset(
|
|||
blank_index=config.get('blank_index', -1),
|
||||
unk_index=config.get('unk_index', -1),
|
||||
normalize=config.get('normalize_transcripts', False),
|
||||
trim=config.get('trim_silence', True),
|
||||
trim=config.get('trim_silence', False),
|
||||
parser=config.get('parser', 'en'),
|
||||
add_misc=config.get('add_misc', False),
|
||||
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
|
||||
|
@ -157,7 +157,7 @@ def get_tarred_bpe_dataset(
|
|||
max_duration=config.get('max_duration', None),
|
||||
min_duration=config.get('min_duration', None),
|
||||
max_utts=config.get('max_utts', 0),
|
||||
trim=config.get('trim_silence', True),
|
||||
trim=config.get('trim_silence', False),
|
||||
add_misc=config.get('add_misc', False),
|
||||
use_start_end_token=config.get('use_start_end_token', True),
|
||||
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
|
||||
|
@ -201,7 +201,7 @@ def get_dali_char_dataset(
|
|||
blank_index=config.get('blank_index', -1),
|
||||
unk_index=config.get('unk_index', -1),
|
||||
normalize=config.get('normalize_transcripts', False),
|
||||
trim=config.get('trim_silence', True),
|
||||
trim=config.get('trim_silence', False),
|
||||
parser=config.get('parser', 'en'),
|
||||
shuffle=shuffle,
|
||||
device_id=device_id,
|
||||
|
|
|
@ -35,8 +35,8 @@ from nemo.utils import logging
|
|||
|
||||
try:
|
||||
import torchaudio
|
||||
import torchaudio.transforms
|
||||
import torchaudio.functional
|
||||
import torchaudio.transforms
|
||||
|
||||
TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
|
||||
TORCHAUDIO_VERSION_MIN = version.parse('0.5')
|
||||
|
|
|
@ -61,8 +61,7 @@ class BeamSearchDecoderWithLM(NeuralModule):
|
|||
):
|
||||
|
||||
try:
|
||||
from ctc_decoders import Scorer
|
||||
from ctc_decoders import ctc_beam_search_decoder_batch
|
||||
from ctc_decoders import Scorer, ctc_beam_search_decoder_batch
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"BeamSearchDecoderWithLM requires the "
|
||||
|
|
|
@ -266,6 +266,7 @@ class FilterbankFeatures(nn.Module):
|
|||
win_length=self.win_length,
|
||||
center=False if stft_exact_pad else True,
|
||||
window=self.window.to(dtype=torch.float),
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
self.normalize = normalize
|
||||
|
|
|
@ -61,7 +61,7 @@ def stft(x, fft_size, hop_size, win_length, window):
|
|||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
"""
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
|
||||
real = x_stft[..., 0]
|
||||
imag = x_stft[..., 1]
|
||||
|
||||
|
|
|
@ -129,7 +129,7 @@ class InverseSTFT(nn.Module):
|
|||
# when the model is in nn.DataParallel
|
||||
# of PyTorch 1.2.0 (py3.7_cuda10.0.130_cudnn7.6.2_01.2)
|
||||
eye_realimag = torch.stack((eye, torch.zeros(n_fft, n_fft)), dim=-1)
|
||||
basis = torch.ifft(eye_realimag, signal_ndim=1) # n_fft, n_fft, 2
|
||||
basis = torch.fft.ifft(eye_realimag, signal_ndim=1) # n_fft, n_fft, 2
|
||||
basis[..., 1] *= -1 # because (a+b*1j)*(c+d*1j) == a*c - b*d
|
||||
basis *= window
|
||||
self.basis = nn.Parameter(basis, requires_grad=False) # n_fft, n_fft, 2
|
||||
|
@ -517,7 +517,7 @@ class DegliModule(NeuralModule, Exportable):
|
|||
self.mode = OperationMode.infer
|
||||
|
||||
def stft(self, x):
|
||||
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window)
|
||||
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, return_complex=False)
|
||||
|
||||
@typecheck()
|
||||
def forward(self, x, mag, max_length=None, repeat=2):
|
||||
|
|
|
@ -592,7 +592,6 @@ class ModelPT(LightningModule, Model):
|
|||
|
||||
finally:
|
||||
cls._set_model_restore_state(is_being_restored=False)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -45,7 +45,7 @@ Instructions
|
|||
|
||||
# Import the API Key
|
||||
try:
|
||||
from freesound_private_apikey import client_id, api_key
|
||||
from freesound_private_apikey import api_key, client_id
|
||||
|
||||
print("API Key found !")
|
||||
except ImportError:
|
||||
|
|
|
@ -127,7 +127,8 @@ def listener_process(queue, configurer, log_file, level):
|
|||
logger.handle(record) # No level or filter logic applied - just do it!
|
||||
|
||||
except Exception:
|
||||
import sys, traceback
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
print('Problem:', file=sys.stderr)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
|
|
Loading…
Reference in a new issue