Refactor and Minimize Dependencies (#2643)

* squash

Signed-off-by: Jason <jasoli@nvidia.com>

* add comments

Signed-off-by: Jason <jasoli@nvidia.com>

* style and cleanup

Signed-off-by: Jason <jasoli@nvidia.com>

* cleanup

Signed-off-by: Jason <jasoli@nvidia.com>

* add new test file

Signed-off-by: Jason <jasoli@nvidia.com>

* syntax

Signed-off-by: Jason <jasoli@nvidia.com>

* style

Signed-off-by: Jason <jasoli@nvidia.com>

* typo

Signed-off-by: Jason <jasoli@nvidia.com>

* update

Signed-off-by: Jason <jasoli@nvidia.com>

* update

Signed-off-by: Jason <jasoli@nvidia.com>

* update

Signed-off-by: Jason <jasoli@nvidia.com>

* try again

Signed-off-by: Jason <jasoli@nvidia.com>

* wip

Signed-off-by: Jason <jasoli@nvidia.com>

* style; ci should fail

Signed-off-by: Jason <jasoli@nvidia.com>

* final

Signed-off-by: Jason <jasoli@nvidia.com>
Jason 2021-08-17 10:55:43 -04:00 committed by GitHub
parent 94126c4b65
commit 4f2ea4913c
34 changed files with 1102 additions and 121 deletions
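The change applies one pattern throughout: imports of heavy optional dependencies (pytorch_lightning, hydra, omegaconf, and collection-specific packages) are wrapped in try/except, and `CheckInstall` stubs take their place when the import fails. A minimal sketch of the pattern (the guarded import below is illustrative, not a real module):

```python
# TODO @blisc: Perhaps refactor instead of import guarding
try:
    from heavy_optional_dependency import RealClass  # illustrative module name
except ModuleNotFoundError:
    from nemo.utils.exceptions import CheckInstall

    # Same-named stub: raises LightningNotInstalledException on instantiation,
    # call, or attribute access, telling the user which extras to install
    class RealClass(CheckInstall):
        pass
```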

Jenkinsfile vendored
View file

@@ -40,6 +40,24 @@ pipeline {
sh 'python setup.py style'
}
}
stage('Torch TTS unit tests') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
steps {
sh 'pip install ".[torch_tts]"'
sh 'pip list'
sh 'test $(pip list | grep -c lightning) -eq 0'
sh 'test $(pip list | grep -c omegaconf) -eq 0'
sh 'test $(pip list | grep -c hydra) -eq 0'
sh 'pytest -m "torch_tts" --cpu tests/collections/tts/test_torch_tts.py --relax_numba_compat'
}
}
stage('Installation') {
steps {
sh './reinstall.sh release'
@@ -60,7 +78,7 @@ pipeline {
stage('L0: Unit Tests GPU') {
steps {
sh 'pytest -m "not pleasefixme" --with_downloads --relax_numba_compat'
sh 'pytest -m "not pleasefixme and not torch_tts" --with_downloads --relax_numba_compat'
}
}
@@ -72,7 +90,7 @@ pipeline {
}
}
steps {
sh 'CUDA_VISIBLE_DEVICES="" pytest -m "not pleasefixme" --cpu --with_downloads --relax_numba_compat'
sh 'CUDA_VISIBLE_DEVICES="" pytest -m "not pleasefixme and not torch_tts" --cpu --with_downloads --relax_numba_compat'
}
}

View file

@@ -16,15 +16,17 @@ import abc
import itertools
import re
import string
import time
import unicodedata
from builtins import str as unicode
from typing import List
import nltk
import torch
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from nemo.collections.common.parts.preprocessing import parsers
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
_words_re = re.compile("([a-z\-]+'[a-z\-]+|[a-z\-]+)|([^a-z{}]+)")
@@ -43,7 +45,6 @@ def _word_tokenize(text):
return words
@rank_zero_only
def download_corpora():
# Download NLTK datasets if this class is to be instantiated
try:
@@ -310,10 +311,29 @@ class Phonemes(Base):
self.spaces = spaces
self.pad_with_space = pad_with_space
download_corpora()
_ = sync_ddp_if_available(torch.tensor(0)) # Barrier until rank 0 downloads the corpora
# g2p_en tries to run download_corpora() on import but it is not rank zero guarded
# Check whether torch distributed is initialized; if it is not, have global rank zero download the corpora
# and make all other ranks sleep for a minute
if torch.distributed.is_available() and torch.distributed.is_initialized():
group = torch.distributed.group.WORLD
if is_global_rank_zero():
download_corpora()
torch.distributed.barrier(group=group)
elif is_global_rank_zero():
logging.error(
f"Torch distributed needs to be initialized before you initialize {self}. This class is prone to "
"data access race conditions. Now downloading corpora from global rank 0. If other ranks pass this "
"point before rank 0, errors might result."
)
download_corpora()
else:
logging.error(
f"Torch distributed needs to be initialized before you initialize {self}. This class is prone to "
"data access race conditions. This process is not rank 0; it will now sleep for 1 min. If this "
"rank wakes before rank 0 finishes downloading, errors might result."
)
time.sleep(60)
import g2p_en # noqa pylint: disable=import-outside-toplevel
_g2p = g2p_en.G2p()

View file

@@ -12,11 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel, ExtractSpeakerEmbeddingsModel
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
# TODO @blisc: Perhaps refactor instead of import guarding
try:
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel, ExtractSpeakerEmbeddingsModel
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
except ModuleNotFoundError:
from nemo.utils.exceptions import CheckInstall
# fmt: off
class ASRModel(CheckInstall): pass
class EncDecClassificationModel(CheckInstall): pass
class ClusteringDiarizer(CheckInstall): pass
class EncDecCTCModelBPE(CheckInstall): pass
class EncDecCTCModel(CheckInstall): pass
class EncDecSpeakerLabelModel(CheckInstall): pass
class ExtractSpeakerEmbeddingsModel(CheckInstall): pass
class EncDecRNNTBPEModel(CheckInstall): pass
class EncDecRNNTModel(CheckInstall): pass
# fmt: on

View file

@@ -20,13 +20,27 @@ from nemo.collections.asr.modules.audio_preprocessing import (
)
from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM
from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
from nemo.collections.asr.modules.conv_asr import (
ConvASRDecoder,
ConvASRDecoderClassification,
ConvASREncoder,
ECAPAEncoder,
ParallelConvASREncoder,
SpeakerDecoder,
)
from nemo.collections.asr.modules.lstm_decoder import LSTMDecoder
from nemo.collections.asr.modules.rnnt import RNNTDecoder, RNNTJoint
# TODO @blisc: Perhaps refactor instead of import guarding
try:
from nemo.collections.asr.modules.conv_asr import (
ConvASRDecoder,
ConvASRDecoderClassification,
ConvASREncoder,
ECAPAEncoder,
ParallelConvASREncoder,
SpeakerDecoder,
)
except ModuleNotFoundError:
from nemo.utils.exceptions import CheckInstall
# fmt: off
class ConvASRDecoder(CheckInstall): pass
class ConvASRDecoderClassification(CheckInstall): pass
class ConvASREncoder(CheckInstall): pass
class ECAPAEncoder(CheckInstall): pass
class ParallelConvASREncoder(CheckInstall): pass
class SpeakerDecoder(CheckInstall): pass
# fmt: on

View file

@@ -41,13 +41,23 @@ import torch.nn as nn
import torch.nn.functional as F
from librosa.util import tiny
from torch.autograd import Variable
from torch_stft import STFT
from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.common.parts.patch_utils import stft_patch
from nemo.utils import logging
# TODO @blisc: Perhaps refactor instead of import guarding
try:
from torch_stft import STFT
except ModuleNotFoundError:
from nemo.utils.exceptions import CheckInstall
# fmt: off
class STFT(CheckInstall): pass
# fmt: on
CONSTANT = 1e-5

View file

@@ -43,8 +43,6 @@ from typing import List, Optional, Union
import librosa
import numpy as np
import soundfile as sf
import webdataset as wd
from omegaconf import DictConfig, OmegaConf
from scipy import signal
from torch.utils.data import IterableDataset
@@ -52,6 +50,17 @@ from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.common.parts.preprocessing import collections, parsers
from nemo.utils import logging
# TODO @blisc: Perhaps refactor instead of import guarding
HAVE_OMEGACONF_WEBDATASET = True
try:
import webdataset as wd
from omegaconf import DictConfig, OmegaConf
except ModuleNotFoundError:
from nemo.utils.exceptions import LightningNotInstalledException
HAVE_OMEGACONF_WEBDATASET = False
try:
from nemo.collections.asr.parts.utils import numba_utils
@@ -792,10 +801,13 @@ def process_augmentations(augmenter) -> Optional[AudioAugmentor]:
if isinstance(augmenter, AudioAugmentor):
return augmenter
if not type(augmenter) in {dict, DictConfig}:
augmenter_types = {dict}
if HAVE_OMEGACONF_WEBDATASET:
augmenter_types = {dict, DictConfig}
if not type(augmenter) in augmenter_types:
raise ValueError("Cannot parse augmenter. Must be a dict or an AudioAugmentor object ")
if isinstance(augmenter, DictConfig):
if HAVE_OMEGACONF_WEBDATASET and isinstance(augmenter, DictConfig):
augmenter = OmegaConf.to_container(augmenter, resolve=True)
augmenter = copy.deepcopy(augmenter)
@@ -864,6 +876,8 @@ class AugmentationDataset(IterableDataset):
if bkey in tar_filepaths:
tar_filepaths = tar_filepaths.replace(bkey, "}")
if not HAVE_OMEGACONF_WEBDATASET:
raise LightningNotInstalledException(self)
self.audio_dataset = wd.WebDataset(urls=tar_filepaths, nodesplitter=None)
if shuffle_n > 0:

View file

@@ -39,13 +39,20 @@ import random
import librosa
import numpy as np
import soundfile as sf
from kaldiio.matio import read_kaldi
from kaldiio.utils import open_like_kaldi
from pydub import AudioSegment as Audio
from pydub.exceptions import CouldntDecodeError
from nemo.utils import logging
# TODO @blisc: Perhaps refactor instead of import guarding
HAVE_KALDI_PYDUB = True
try:
from kaldiio.matio import read_kaldi
from kaldiio.utils import open_like_kaldi
from pydub import AudioSegment as Audio
from pydub.exceptions import CouldntDecodeError
except ModuleNotFoundError:
HAVE_KALDI_PYDUB = False
available_formats = sf.available_formats()
sf_supported_formats = ["." + i.lower() for i in available_formats.keys()]
@@ -149,7 +156,7 @@ class AudioSegment(object):
f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`. "
f"NeMo will fallback to loading via pydub."
)
elif isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
elif HAVE_KALDI_PYDUB and isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
f = open_like_kaldi(audio_file, "rb")
sample_rate, samples = read_kaldi(f)
if offset > 0:
@@ -160,7 +167,7 @@ class AudioSegment(object):
abs_max_value = np.abs(samples).max()
samples = np.array(samples, dtype=np.float) / abs_max_value
if samples is None:
if HAVE_KALDI_PYDUB and samples is None:
try:
samples = Audio.from_file(audio_file)
sample_rate = samples.frame_rate
@@ -172,8 +179,12 @@ class AudioSegment(object):
seconds = duration * 1000
samples = samples[: int(seconds)]
samples = np.array(samples.get_array_of_samples())
except CouldntDecodeError as e:
logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{e}`.")
except CouldntDecodeError as err:
logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{err}`.")
if samples is None:
libs = "soundfile, kaldiio, and pydub" if HAVE_KALDI_PYDUB else "soundfile"
raise Exception(f"Your audio file {audio_file} could not be decoded. We tried using {libs}.")
return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)

View file

@@ -12,4 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
# TODO @blisc: Perhaps refactor instead of import guarding
try:
from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
except ModuleNotFoundError:
from nemo.utils.exceptions import CheckInstall
# fmt: off
class LogEpochTimeCallback(CheckInstall): pass
# fmt: on

View file

@@ -11,6 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
A collection of simple character-based parsers. These parsers handle cleaning and tokenization by default.
We currently support English.
'''
import string
from typing import List, Optional
@@ -36,6 +42,7 @@ class CharParser:
blank_id: int = -1,
do_normalize: bool = True,
do_lowercase: bool = True,
do_tokenize: bool = True,
):
"""Creates simple mapping char parser.
@@ -56,6 +63,7 @@
self._blank_id = blank_id
self._do_normalize = do_normalize
self._do_lowercase = do_lowercase
self._do_tokenize = do_tokenize
self._labels_map = {label: index for index, label in enumerate(labels)}
self._special_labels = set([label for label in labels if len(label) > 1])
@@ -66,8 +74,10 @@
if text is None:
return None
text_tokens = self._tokenize(text)
if not self._do_tokenize:
return text
text_tokens = self._tokenize(text)
return text_tokens
def _normalize(self, text: str) -> Optional[str]:
@@ -97,6 +107,23 @@
return tokens
def decode(self, str_input):
r_map = {}
for k, v in self._labels_map.items():
r_map[v] = k
r_map[len(self._labels_map)] = "<BOS>"
r_map[len(self._labels_map) + 1] = "<EOS>"
r_map[len(self._labels_map) + 2] = "<P>"
out = []
for i in str_input:
# Skip OOV
if i not in r_map:
continue
out.append(r_map[i.item()])
return "".join(out)
class ENCharParser(CharParser):
"""Incorporates english-specific parsing logic."""

View file

@@ -14,7 +14,17 @@
from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer
from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
# TODO @blisc: Perhaps refactor instead of import guarding
try:
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
except ModuleNotFoundError:
from nemo.utils.exceptions import CheckInstall
# fmt: off
class AutoTokenizer(CheckInstall): pass
class SentencePieceTokenizer(CheckInstall): pass
# fmt: on

View file

@@ -23,10 +23,22 @@ from numba import jit, prange
from numpy import ndarray
from pesq import pesq
from pystoi import stoi
from pytorch_lightning.utilities import rank_zero_only
from nemo.utils import logging
try:
from pytorch_lightning.utilities import rank_zero_only
except ModuleNotFoundError:
from functools import wraps
def rank_zero_only(fn):
@wraps(fn)
def wrapped_fn(*args, **kwargs):
logging.error(
f"Function {fn} requires lightning to be installed, but it was not found. Please install lightning first."
)
exit(1)
return wrapped_fn # return the stub so @rank_zero_only remains a valid decorator
class OperationMode(Enum):
"""Training or Inference (Evaluation) mode"""

View file

@@ -12,22 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo.collections.tts.models.aligner import AlignerModel
from nemo.collections.tts.models.degli import DegliModel
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
from nemo.collections.tts.models.fastpitch import FastPitchModel
from nemo.collections.tts.models.fastpitch_hifigan_e2e import FastPitchHifiGanE2EModel
from nemo.collections.tts.models.fastspeech2 import FastSpeech2Model
from nemo.collections.tts.models.fastspeech2_hifigan_e2e import FastSpeech2HifiGanE2EModel
from nemo.collections.tts.models.glow_tts import GlowTTSModel
from nemo.collections.tts.models.hifigan import HifiGanModel
from nemo.collections.tts.models.melgan import MelGanModel
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
from nemo.collections.tts.models.talknet import TalkNetDursModel, TalkNetPitchModel, TalkNetSpectModel
from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel
from nemo.collections.tts.models.uniglow import UniGlowModel
from nemo.collections.tts.models.waveglow import WaveGlowModel
try:
from nemo.collections.tts.models.aligner import AlignerModel
from nemo.collections.tts.models.degli import DegliModel
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
from nemo.collections.tts.models.fastpitch import FastPitchModel
from nemo.collections.tts.models.fastpitch_hifigan_e2e import FastPitchHifiGanE2EModel
from nemo.collections.tts.models.fastspeech2 import FastSpeech2Model
from nemo.collections.tts.models.fastspeech2_hifigan_e2e import FastSpeech2HifiGanE2EModel
from nemo.collections.tts.models.glow_tts import GlowTTSModel
from nemo.collections.tts.models.hifigan import HifiGanModel
from nemo.collections.tts.models.melgan import MelGanModel
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
from nemo.collections.tts.models.talknet import TalkNetDursModel, TalkNetPitchModel, TalkNetSpectModel
from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel
from nemo.collections.tts.models.uniglow import UniGlowModel
from nemo.collections.tts.models.waveglow import WaveGlowModel
except ModuleNotFoundError:
pass
__all__ = [
"GlowTTSModel",

View file

@@ -44,8 +44,8 @@ import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from nemo.core.classes import NeuralModule
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
from nemo.core.neural_types.neural_type import NeuralType

View file

@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@@ -0,0 +1,521 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle
from pathlib import Path
from typing import Dict, Optional
import librosa
import torch
from nemo.collections.asr.data.vocabs import Base, Phonemes
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.common.parts.patch_utils import stft_patch
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.collections.tts.torch.helpers import beta_binomial_prior_distribution
from nemo.core.classes import Dataset
from nemo.core.neural_types.elements import *
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging
CONSTANT = 1e-5
class CharMelAudioDataset(Dataset):
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {
'transcripts': NeuralType(('B', 'T'), TokenIndex()),
'transcript_length': NeuralType(('B'), LengthsType()),
'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
'mel_length': NeuralType(('B'), LengthsType()),
'audio': NeuralType(('B', 'T'), AudioSignal()),
'audio_length': NeuralType(('B'), LengthsType()),
'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
'energies': NeuralType(('B', 'T'), RegressionValuesType()),
}
def __init__(
self,
manifest_filepath: str,
sample_rate: int,
supplementary_folder: Path,
max_duration: Optional[float] = None,
min_duration: Optional[float] = None,
ignore_file: Optional[str] = None,
trim: bool = False,
n_fft=1024,
win_length=None,
hop_length=None,
window="hann",
n_mels=64,
lowfreq=0,
highfreq=None,
pitch_fmin=80,
pitch_fmax=640,
pitch_avg=0,
pitch_std=1,
tokenize_text=True,
):
"""Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
they do not already exist.
Args:
manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
json. Each line should contain the following:
"audio_filepath": <PATH_TO_WAV>
"mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
"duration": <Duration of audio clip in seconds> (Optional)
"text": <THE_TRANSCRIPT> (Optional)
sample_rate (int): The sample rate of the audio, or the sample rate to which all files will be resampled.
supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
specified in the manifest .json file. It will also contain priors, pitches, and energies
max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
audio to compute duration. Defaults to None which does not prune.
min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
audio to compute duration. Defaults to None which does not prune.
ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
files) that will be pruned prior to training. Defaults to None which does not prune.
trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
n_fft (Optional[int]): The number of fft samples. Defaults to 1024
win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
hop_length (Optional[int]): The hop length between stft computations. Defaults to None which uses n_fft//4.
window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', or 'none', corresponding to the
equivalent torch window function.
n_mels (Optional[int]): The number of mel filters. Defaults to 64.
lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to 80.
pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to 640.
pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
"""
super().__init__()
self.pitch_fmin = pitch_fmin
self.pitch_fmax = pitch_fmax
self.pitch_avg = pitch_avg
self.pitch_std = pitch_std
self.win_length = win_length or n_fft
self.sample_rate = sample_rate
self.hop_len = hop_length or n_fft // 4
self.parser = make_parser(name="en", do_tokenize=tokenize_text)
self.pad_id = self.parser._blank_id
Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
self.supplementary_folder = supplementary_folder
audio_files = []
total_duration = 0
# Load data from manifests
# Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
# extract pitch from the audio
# Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
# compute the mel on the fly and save it to the supplementary folder
# Note: text is not required. Any models that rely on text (spectrogram generators, end-to-end models) will
# fail if it is not set. However, vocoders (mel -> audio) will work without text
if isinstance(manifest_filepath, str):
manifest_filepath = [manifest_filepath]
for manifest_file in manifest_filepath:
with open(Path(manifest_file).expanduser(), 'r') as f:
logging.info(f"Loading dataset from {manifest_file}.")
for line in f:
item = json.loads(line)
# Grab audio, text, mel if they exist
file_info = {}
file_info["audio_filepath"] = item["audio_filepath"]
file_info["mel_filepath"] = item["mel_filepath"] if "mel_filepath" in item else None
file_info["duration"] = item["duration"] if "duration" in item else None
# Parse text
file_info["text_tokens"] = None
if "text" in item:
text = item["text"]
text_tokens = self.parser(text)
file_info["text_tokens"] = text_tokens
audio_files.append(file_info)
if file_info["duration"] is None:
logging.info(
"Not all audio files have duration information. Duration logging will be disabled."
)
total_duration = None
if total_duration is not None:
total_duration += item["duration"]
logging.info(f"Loaded dataset with {len(audio_files)} files.")
if total_duration is not None:
logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")
self.data = []
if ignore_file:
logging.info(f"using {ignore_file} to prune dataset.")
with open(Path(ignore_file).expanduser(), "rb") as f:
wavs_to_ignore = set(pickle.load(f))
pruned_duration = 0 if total_duration is not None else None
pruned_items = 0
for item in audio_files:
audio_path = item['audio_filepath']
audio_id = Path(audio_path).stem
# Prune data according to min/max_duration & the ignore file
if total_duration is not None:
if (min_duration and item["duration"] < min_duration) or (
max_duration and item["duration"] > max_duration
):
pruned_duration += item["duration"]
pruned_items += 1
continue
if ignore_file and (audio_id in wavs_to_ignore):
pruned_items += 1
pruned_duration += item["duration"]
wavs_to_ignore.remove(audio_id)
continue
self.data.append(item)
logging.info(f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files")
if pruned_duration is not None:
logging.info(
f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
f"{(total_duration-pruned_duration)/3600:.2f} hours."
)
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
self.trim = trim
filterbanks = torch.tensor(
librosa.filters.mel(sample_rate, n_fft, n_mels=n_mels, fmin=lowfreq, fmax=highfreq), dtype=torch.float
).unsqueeze(0)
self.fb = filterbanks
torch_windows = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'none': None,
}
window_fn = torch_windows.get(window, None)
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
self.stft = lambda x: stft_patch(
input=x,
n_fft=n_fft,
hop_length=self.hop_len,
win_length=self.win_length,
window=window_tensor.to(torch.float),
)
def __getitem__(self, index):
spec = None
sample = self.data[index]
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
audio, audio_length = features, torch.tensor(features.shape[0]).long()
if isinstance(sample["text_tokens"], str):
# If tokenize_text is False for Phone dataset
text = sample["text_tokens"]
text_length = None
else:
text = torch.tensor(sample["text_tokens"]).long()
text_length = torch.tensor(len(sample["text_tokens"])).long()
audio_stem = Path(sample["audio_filepath"]).stem
# Load mel if it exists
mel_path = sample["mel_filepath"]
if mel_path and Path(mel_path).exists():
log_mel = torch.load(mel_path)
else:
mel_path = Path(self.supplementary_folder) / f"mel_{audio_stem}.pt"
if mel_path.exists():
log_mel = torch.load(mel_path)
else:
# disable autocast to get full range of stft values
with torch.cuda.amp.autocast(enabled=False):
spec = self.stft(audio)
# guard is needed for sqrt if grads are passed through
guard = CONSTANT # TODO: Enable 0 if not self.use_grads else CONSTANT
if spec.dtype in [torch.cfloat, torch.cdouble]:
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + guard)
mel = torch.matmul(self.fb.to(spec.dtype), spec)
log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
torch.save(log_mel, mel_path)
log_mel = log_mel.squeeze(0)
log_mel_length = torch.tensor(log_mel.shape[1]).long()
duration_prior = None
if text_length is not None:
### Make the duration attention prior if it does not exist in the supplementary folder
prior_path = Path(self.supplementary_folder) / f"pr_tl{text_length}_al_{log_mel_length}.pt"
if prior_path.exists():
duration_prior = torch.load(prior_path)
else:
duration_prior = beta_binomial_prior_distribution(text_length, log_mel_length)
duration_prior = torch.from_numpy(duration_prior)
torch.save(duration_prior, prior_path)
# Load pitch file (F0s)
pitch_path = (
Path(self.supplementary_folder)
/ f"{audio_stem}_pitch_pyin_fmin{self.pitch_fmin}_fmax{self.pitch_fmax}_fl{self.win_length}_hs{self.hop_len}.pt"
)
if pitch_path.exists():
pitch = torch.load(pitch_path)
else:
pitch, _, _ = librosa.pyin(
audio.numpy(),
fmin=self.pitch_fmin,
fmax=self.pitch_fmax,
frame_length=self.win_length,
sr=self.sample_rate,
fill_na=0.0,
)
pitch = torch.from_numpy(pitch)
torch.save(pitch, pitch_path)
# Standardize pitch
pitch -= self.pitch_avg
pitch[pitch == -self.pitch_avg] = 0.0 # Zero out values that were previously zero
pitch /= self.pitch_std
# Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
energy_path = Path(self.supplementary_folder) / f"{audio_stem}_energy_wl{self.win_length}_hs{self.hop_len}.pt"
if energy_path.exists():
energy = torch.load(energy_path)
else:
if spec is None:
spec = self.stft(audio)
energy = torch.linalg.norm(spec.squeeze(0), dim=0) # torch.linalg.norm takes `dim`, not numpy-style `axis`
# Save to new file
torch.save(energy, energy_path)
return text, text_length, log_mel, log_mel_length, audio, audio_length, duration_prior, pitch, energy
def __len__(self):
return len(self.data)
def _collate_fn(self, batch):
log_mel_pad = torch.finfo(batch[0][2].dtype).tiny
_, tokens_lengths, _, log_mel_lengths, _, audio_lengths, duration_priors_list, pitches, energies = zip(*batch)
max_tokens_len = max(tokens_lengths).item()
max_log_mel_len = max(log_mel_lengths)
max_audio_len = max(audio_lengths).item()
max_pitches_len = max([len(i) for i in pitches])
max_energies_len = max([len(i) for i in energies])
if max_pitches_len != max_energies_len or max_pitches_len != max_log_mel_len:
logging.warning(
f"max_pitches_len: {max_pitches_len} != max_energies_len: {max_energies_len} != "
f"max_mel_len:{max_log_mel_len}. Your training run will error out!"
)
# Define empty lists to be batched
duration_priors = torch.zeros(
len(duration_priors_list),
max([prior_i.shape[0] for prior_i in duration_priors_list]),
max([prior_i.shape[1] for prior_i in duration_priors_list]),
)
audios, tokens, log_mels, pitches, energies = [], [], [], [], []
for i, sample_tuple in enumerate(batch):
token, token_len, log_mel, log_mel_len, audio, audio_len, duration_prior, pitch, energy = sample_tuple
# Pad text tokens
token_len = token_len.item()
if token_len < max_tokens_len:
pad = (0, max_tokens_len - token_len)
token = torch.nn.functional.pad(token, pad, value=self.pad_id)
tokens.append(token)
# Pad mel
log_mel_len = log_mel_len
if log_mel_len < max_log_mel_len:
pad = (0, max_log_mel_len - log_mel_len)
log_mel = torch.nn.functional.pad(log_mel, pad, value=log_mel_pad)
log_mels.append(log_mel)
# Pad audio
audio_len = audio_len.item()
if audio_len < max_audio_len:
pad = (0, max_audio_len - audio_len)
audio = torch.nn.functional.pad(audio, pad)
audios.append(audio)
# Pad duration_prior
duration_priors[i, : duration_prior.shape[0], : duration_prior.shape[1]] = duration_prior
# Pad pitch
if len(pitch) < max_pitches_len:
pad = (0, max_pitches_len - len(pitch))
pitch = torch.nn.functional.pad(pitch, pad)
pitches.append(pitch)
# Pad energy
if len(energy) < max_energies_len:
pad = (0, max_energies_len - len(energy))
energy = torch.nn.functional.pad(energy, pad)
energies.append(energy)
audios = torch.stack(audios)
audio_lengths = torch.stack(audio_lengths)
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
log_mels = torch.stack(log_mels)
log_mel_lengths = torch.stack(log_mel_lengths)
pitches = torch.stack(pitches)
energies = torch.stack(energies)
logging.debug(f"audios: {audios.shape}")
logging.debug(f"audio_lengths: {audio_lengths.shape}")
logging.debug(f"tokens: {tokens.shape}")
logging.debug(f"tokens_lengths: {tokens_lengths.shape}")
logging.debug(f"log_mels: {log_mels.shape}")
logging.debug(f"log_mel_lengths: {log_mel_lengths.shape}")
logging.debug(f"duration_priors: {duration_priors.shape}")
logging.debug(f"pitches: {pitches.shape}")
logging.debug(f"energies: {energies.shape}")
return (tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies)
def decode(self, tokens):
assert len(tokens.squeeze().shape) in [0, 1]
return self.parser.decode(tokens)
class PhoneMelAudioDataset(CharMelAudioDataset):
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {
'transcripts': NeuralType(('B', 'T'), TokenIndex()),
'transcript_length': NeuralType(('B'), LengthsType()),
'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
'mel_length': NeuralType(('B'), LengthsType()),
'audio': NeuralType(('B', 'T'), AudioSignal()),
'audio_length': NeuralType(('B'), LengthsType()),
'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
'energies': NeuralType(('B', 'T'), RegressionValuesType()),
}
def __init__(
self,
punct=True,
stresses=False,
spaces=True,
chars=False,
space=' ',
silence=None,
apostrophe=True,
oov=Base.OOV,
sep='|',
add_blank_at="last_but_one",
pad_with_space=False,
improved_version_g2p=False,
phoneme_dict_path=None,
**kwargs,
):
"""Dataset which extends CharMelAudioDataset to load phones in place of characters. It returns audio, log mel
specs, phone tokens, duration / attention priors, pitches, and energies. Log mels, priors, pitches, and
energies will be computed on the fly and saved in the supplementary_folder if they did not exist before. These
supplementary files can be shared with CharMelAudioDataset.
Args:
punct (bool): Whether to keep punctuation in the input. Defaults to True
stresses (bool): Whether to add phone stresses in the input. Defaults to False
spaces (bool): Whether to encode space characters. Defaults to True
chars (bool): Whether to add characters to the labels map. NOTE: The current parser class does not
actually parse transcripts to characters. Defaults to False
space (str): The space character. Defaults to ' '
silence (bool): Whether to add silence tokens. Defaults to None
apostrophe (bool): Whether to keep apostrophes. Defaults to True
oov (str): How out of vocabulary tokens are decoded. Defaults to Base.OOV == "<oov>"
sep (str): How to separate phones when tokens are decoded. Defaults to "|"
add_blank_at (str): Where to add the blank symbol that is used in CTC. Can be None which does not add a
blank token in the vocab, "last" which makes self.vocab.labels[-1] the blank token, or
"last_but_one" which makes self.vocab.labels[-2] the blank token
pad_with_space (bool): Whether to pad the input with space tokens at the start and end. Defaults to False
improved_version_g2p (bool): Defaults to False
phoneme_dict_path (path): Location of cmudict. Defaults to None which means the code will download it
automatically
"""
if "tokenize_text" in kwargs:
tokenize_text = kwargs.pop("tokenize_text")
if tokenize_text:
logging.warning(
f"{self} requires tokenize_text to be False. Setting it to False and ignoring provided value of "
f"{tokenize_text}"
)
super().__init__(tokenize_text=False, **kwargs)
self.vocab = Phonemes(
punct=punct,
stresses=stresses,
spaces=spaces,
chars=chars,
add_blank_at=add_blank_at,
pad_with_space=pad_with_space,
improved_version_g2p=improved_version_g2p,
phoneme_dict_path=phoneme_dict_path,
space=space,
silence=silence,
apostrophe=apostrophe,
oov=oov,
sep=sep,
)
self.pad_id = self.vocab.pad
def __getitem__(self, index):
(text, _, log_mel, log_mel_length, audio, audio_length, _, pitch, energy) = super().__getitem__(index)
phones_tokenized = torch.tensor(self.vocab.encode(text)).long()
phones_length = torch.tensor(len(phones_tokenized)).long()
### Make the duration attention prior if it does not exist in the supplementary folder
prior_path = Path(self.supplementary_folder) / f"pr_tl{phones_length}_al_{log_mel_length}.pt"
if prior_path.exists():
duration_prior = torch.load(prior_path)
else:
duration_prior = beta_binomial_prior_distribution(phones_length, log_mel_length)
duration_prior = torch.from_numpy(duration_prior)
torch.save(duration_prior, prior_path)
return (
phones_tokenized,
phones_length,
log_mel,
log_mel_length,
audio,
audio_length,
duration_prior,
pitch,
energy,
)
def __len__(self):
return len(self.data)
def decode(self, tokens):
"""
Accepts a single list of tokens, not a batch
"""
assert len(tokens.squeeze().shape) in [0, 1]
return self.vocab.decode(tokens)

View file

@@ -0,0 +1,26 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.stats import betabinom
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
x = np.arange(0, phoneme_count)
mel_text_probs = []
for i in range(1, mel_count + 1):
a, b = scaling_factor * i, scaling_factor * (mel_count + 1 - i)
mel_i_prob = betabinom(phoneme_count, a, b).pmf(x)
mel_text_probs.append(mel_i_prob)
return np.array(mel_text_probs)
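For intuition: for each of the `mel_count` frames the helper evaluates a Beta-Binomial pmf over the `phoneme_count` positions, so the result is a (mel_count, phoneme_count) array. A quick check with illustrative sizes:

```python
prior = beta_binomial_prior_distribution(phoneme_count=5, mel_count=8)
print(prior.shape)  # (8, 5): one row of alignment probabilities per mel frame
```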

View file

@@ -0,0 +1,106 @@
# Torch TTS Collection
This collection can be used by installing only the requirements inside *requirements.txt* and *requirements_torch_tts.txt*.
## Install
This collection can be installed in the following ways:
- pip install from github
> ```bash
> pip install git+https://github.com/nvidia/NeMo.git#egg=nemo_toolkit[torch_tts]
> ```
- inside a requirements file
> `git+https://github.com/nvidia/NeMo.git#egg=nemo_toolkit[torch_tts]`
- cloning from github, and then installing
> ```bash
> git clone https://github.com/nvidia/NeMo.git && cd NeMo && pip install ".[torch_tts]"
> ```
## Usage
We can check that lightning is not installed by checking pip:
```bash
pip list | grep lightning
```
Now even though lightning isn't installed, we can still use parts from the torch_tts collection.
### TTS Dataset
Let's import our dataset class and then loop through the batches. Note that in the sample .json files, we only have text
and audio. Our dataset will then create the log_mels, priors, pitches, and energies and store them in `supplementary_folder`
which in this case is `./debug0`.
```python
import torch
from nemo.collections.tts.torch.data import CharMelAudioDataset
dataset = CharMelAudioDataset(
manifest_filepath="<PATH_TO_MANIFEST_JSON>", # Path to file that describes the location of audio and text
sample_rate=22050,
supplementary_folder="./debug0", # An additional folder that will store log_mels, priors, pitches, and energies
max_duration=20., # Max duration of samples in seconds
min_duration=0.1, # Min duration of samples in seconds
ignore_file=None,
trim=False, # Whether to use librosa.effects.trim
n_fft=1024,
win_length=1024,
hop_length=256,
window="hann",
n_mels=64, # Number of mel filters
lowfreq=0, # lowfreq for mel filters
highfreq=8000, # highfreq for mel filters
pitch_fmin=80,
pitch_fmax=640,
)
dataloader = torch.utils.data.DataLoader(dataset, 10, collate_fn=dataset._collate_fn)
for batch in dataloader:
tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies = batch
## Train models, etc.
# Tokens are already tokenized characters, which will probably not work with previous tokenizers
# You can get the label map from dataset.parser._labels_map. You can tokenize text via dataset.parser("text!")
# You can detokenize using dataset.decode()
```
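The phone-based dataset is used the same way; it extends `CharMelAudioDataset` and swaps the character vocabulary for phones: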
```python
import torch
from nemo.collections.tts.torch.data import PhoneMelAudioDataset
dataset = PhoneMelAudioDataset(
manifest_filepath="<PATH_TO_MANIFEST_JSON>", # Path to file that describes the location of audio and text
sample_rate=22050,
supplementary_folder="./debug0", # An additional folder that will store log_mels, priors, pitches, and energies
)
dataloader = torch.utils.data.DataLoader(dataset, 10, collate_fn=dataset._collate_fn)
for batch in dataloader:
tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies = batch
## Train models, etc.
# Tokens are already tokenized characters, which will probably not work with previous tokenizers
# You can tokenize via dataset.vocab.encode(), and go backwards with dataset.vocab.decode().
```
## NeMo Features
If you look at the code, you will see that `CharMelAudioDataset` is a child of `nemo.core.classes.Dataset`. **You do not have to subclass this to add to the torch_tts collection**. It is sufficient to use `torch.utils.data.Dataset`. Using the nemo class adds some additional features not present in torch:
- *(Optional)* Adding typing information
- Looking at `CharMelAudioDataset`, it has an `output_types` property that tells us how the returned tensors are organized. For example, the `mels` returned by this dataset have dimensions B x D x T, which is short for saying the first dimension represents batch, the second a generic channels / n_mel_filters dimension, and the last time
- *(Optional)* *(ToDo)* Enables serialization
- We can call `to_config_dict()` to get a dictionary which we can then pass to `from_config_dict()` to create another instance of the dataset with the same arguments, allowing us to easily re-create objects from these dictionaries; a sketch follows below. Please remember to change any local paths when moving between computers.
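A rough sketch of the intended round trip (serialization is marked ToDo above, so treat this as illustrative rather than final API):

```python
cfg = dataset.to_config_dict()  # config capturing the constructor arguments
dataset_copy = CharMelAudioDataset.from_config_dict(cfg)  # identically-configured instance
```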
## ToDos
- [ ] Populate *torch_tts*
- [x] Create a new datalayer that can be used interchangeably
- [x] Add phone support
- [ ] Add TTS models
- [ ] Split Lightning away from core
- [x] v0.1 that import checks a lot of lightning
- [ ] Split up code (core, collections, utils) better
- [ ] Enable building *text_normalization* without installing lightning
- [ ] Look into how `Serialization` works without hydra

View file

@@ -25,5 +25,17 @@ from nemo.core.classes.common import (
from nemo.core.classes.dataset import Dataset, IterableDataset
from nemo.core.classes.exportable import Exportable, ExportFormat
from nemo.core.classes.loss import Loss
from nemo.core.classes.modelPT import ModelPT
from nemo.core.classes.module import NeuralModule
from nemo.utils import exceptions
# TODO @blisc: Perhaps refactor instead of import guarding
try:
import pytorch_lightning
import hydra
import omegaconf
from nemo.core.classes.modelPT import ModelPT
except ModuleNotFoundError:
from nemo.utils.exceptions import CheckInstall
class ModelPT(CheckInstall):
pass

View file

@@ -24,17 +24,28 @@ from functools import total_ordering
from pathlib import Path
from typing import Dict, List, Optional, Union
import hydra
import wrapt
from omegaconf import DictConfig, OmegaConf
import nemo
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
from nemo.utils import logging
from nemo.utils.cloud import maybe_download_from_cloud
from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version
# TODO @blisc: Perhaps refactor instead of import guarding
_HAS_HYDRA = True
try:
import hydra
from omegaconf import DictConfig, OmegaConf
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
except ModuleNotFoundError:
_HAS_HYDRA = False
from nemo.utils.exceptions import CheckInstall
class SaveRestoreConnector(CheckInstall):
pass
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck']
@@ -418,22 +429,23 @@ class Typing(ABC):
class Serialization(ABC):
@classmethod
def from_config_dict(cls, config: DictConfig):
def from_config_dict(cls, config: 'DictConfig'):
"""Instantiates object using DictConfig-based configuration"""
# Resolve the config dict
if isinstance(config, DictConfig):
config = OmegaConf.to_container(config, resolve=True)
config = OmegaConf.create(config)
OmegaConf.set_struct(config, True)
if _HAS_HYDRA:
if isinstance(config, DictConfig):
config = OmegaConf.to_container(config, resolve=True)
config = OmegaConf.create(config)
OmegaConf.set_struct(config, True)
config = maybe_update_config_version(config)
config = maybe_update_config_version(config)
# Hydra 0.x API
if ('cls' in config or 'target' in config) and 'params' in config:
if ('cls' in config or 'target' in config) and 'params' in config and _HAS_HYDRA:
# regular hydra-based instantiation
instance = hydra.utils.instantiate(config=config)
# Hydra 1.x API
elif '_target_' in config:
elif '_target_' in config and _HAS_HYDRA:
# regular hydra-based instantiation
instance = hydra.utils.instantiate(config=config)
else:
@@ -441,7 +453,7 @@ class Serialization(ABC):
imported_cls_tb = None
# Attempt class path resolution from config `target` class (if it exists)
if 'target' in config:
target_cls = config.target
target_cls = config["target"] # No guarantee that this is an omegaconf object
imported_cls = None
try:
# try to import the target class
@@ -476,15 +488,16 @@ class Serialization(ABC):
instance._cfg = config
return instance
def to_config_dict(self) -> DictConfig:
def to_config_dict(self) -> 'DictConfig':
"""Returns object's configuration to config dictionary"""
if hasattr(self, '_cfg') and self._cfg is not None and isinstance(self._cfg, DictConfig):
if hasattr(self, '_cfg') and self._cfg is not None:
# Resolve the config dict
config = OmegaConf.to_container(self._cfg, resolve=True)
config = OmegaConf.create(config)
OmegaConf.set_struct(config, True)
if _HAS_HYDRA and isinstance(self._cfg, DictConfig):
config = OmegaConf.to_container(self._cfg, resolve=True)
config = OmegaConf.create(config)
OmegaConf.set_struct(config, True)
config = maybe_update_config_version(config)
config = maybe_update_config_version(config)
self._cfg = config

View file

@@ -156,7 +156,7 @@ class ModelPT(LightningModule, Model):
self, config_path: str, src: str, verify_src_exists: bool = True,
):
""" Register model artifacts with this function. These artifacts (files) will be included inside .nemo file
when model.save_to("mymodel.nemo") is called.
when model.save_to("mymodel.nemo") is called.
How it works:
1. It always returns existing absolute path which can be used during Model constructor call
@@ -174,7 +174,7 @@ class ModelPT(LightningModule, Model):
Args:
config_path (str): Artifact key. Usually corresponds to the model config.
src (str): Path to artifact.
verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will return None even if
verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will return None even if
src is not found. Defaults to True.
save_restore_connector (SaveRestoreConnector): Can be overridden to add custom save and restore logic.

View file

@@ -16,7 +16,11 @@
from nemo.utils.app_state import AppState
from nemo.utils.nemo_logging import Logger as _Logger
from nemo.utils.nemo_logging import LogMode as logging_mode
from nemo.utils.lightning_logger_patch import add_memory_handlers_to_pl_logger
logging = _Logger()
add_memory_handlers_to_pl_logger()
try:
from nemo.utils.lightning_logger_patch import add_memory_handlers_to_pl_logger
add_memory_handlers_to_pl_logger()
except ModuleNotFoundError:
pass

View file

@@ -17,13 +17,19 @@ import inspect
from dataclasses import is_dataclass
from typing import Dict, List, Optional
from omegaconf import DictConfig, OmegaConf, open_dict
from nemo.core.config.modelPT import NemoConfig
from nemo.utils import logging
# TODO @blisc: Perhaps refactor instead of import guarding
_HAS_HYDRA = True
try:
from omegaconf import DictConfig, OmegaConf, open_dict
except ModuleNotFoundError:
_HAS_HYDRA = False
def update_model_config(model_cls: NemoConfig, update_cfg: DictConfig, drop_missing_subconfigs: bool = True):
def update_model_config(
model_cls: 'nemo.core.config.modelPT.NemoConfig', update_cfg: 'DictConfig', drop_missing_subconfigs: bool = True
):
"""
Helper class that updates the default values of a ModelPT config class with the values
in a DictConfig that mirrors the structure of the config class.
@@ -52,6 +58,9 @@ def update_model_config(model_cls: NemoConfig, update_cfg: DictConfig, drop_miss
A DictConfig with updated values that can be used to instantiate the NeMo Model along with supporting
infrastructure.
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
raise ValueError("`model_cfg` must be a dataclass or a structured OmegaConf object")
@@ -92,7 +101,7 @@ def update_model_config(model_cls: NemoConfig, update_cfg: DictConfig, drop_miss
def _update_subconfig(
model_cfg: DictConfig, update_cfg: DictConfig, subconfig_key: str, drop_missing_subconfigs: bool
model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str, drop_missing_subconfigs: bool
):
"""
Updates the NemoConfig DictConfig such that:
@@ -112,6 +121,9 @@ def _update_subconfig(
Returns:
The updated DictConfig for the NemoConfig
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
with open_dict(model_cfg.model):
# If the update config has the key, but the model cfg doesn't have the key
# Add the update cfg subconfig to the model cfg
@@ -127,7 +139,7 @@ def _update_subconfig(
return model_cfg
def _add_subconfig_keys(model_cfg: DictConfig, update_cfg: DictConfig, subconfig_key: str):
def _add_subconfig_keys(model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str):
"""
For certain sub-configs, the default values specified by the NemoConfig class are insufficient.
In order to support every potential value in the merge between the `update_cfg`, it would require
@@ -155,6 +167,9 @@ def _add_subconfig_keys(model_cfg: DictConfig, update_cfg: DictConfig, subconfig
Returns:
A ModelPT DictConfig with additional keys added to the sub-config.
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
with open_dict(model_cfg.model):
# Create copy of original model sub config
if subconfig_key in update_cfg.model:

View file

@@ -16,4 +16,22 @@
class NeMoBaseException(Exception):
""" NeMo Base Exception. All exceptions created in NeMo should inherit from this class"""
pass
class LightningNotInstalledException(NeMoBaseException):
def __init__(self, obj):
message = (
f"You are trying to use {obj} without installing all of pytorch_lightning, hydra, and "
f"omegaconf. Please install those packages before trying to access {obj}."
)
super().__init__(message)
class CheckInstall:
def __init__(self, *args, **kwargs):
raise LightningNotInstalledException(self)
def __call__(self, *args, **kwargs):
raise LightningNotInstalledException(self)
def __getattr__(self, *args, **kwargs):
raise LightningNotInstalledException(self)
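The effect, assuming lightning and friends are absent, is that guarded imports still succeed and the failure is deferred to first use:

```python
# The import succeeds even without lightning; the name is bound to a CheckInstall stub
from nemo.collections.common.callbacks import LogEpochTimeCallback

LogEpochTimeCallback()  # raises LightningNotInstalledException with an install hint
```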

View file

@@ -19,14 +19,22 @@ from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
import pytorch_lightning as pl
import wrapt
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf import errors as omegaconf_errors
from packaging import version
from nemo.utils import logging
# TODO @blisc: Perhaps refactor instead of import guarding
_HAS_HYDRA = True
try:
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf import errors as omegaconf_errors
from packaging import version
except ModuleNotFoundError:
_HAS_HYDRA = False
_VAL_TEST_FASTPATH_KEY = 'ds_item'
@@ -49,7 +57,7 @@ class ArtifactItem:
hashed_path: Optional[str] = None
def resolve_dataset_name_from_cfg(cfg: DictConfig) -> str:
def resolve_dataset_name_from_cfg(cfg: 'DictConfig') -> str:
"""
Parses items of the provided sub-config to find the first potential key that
resolves to an existing file or directory.
@@ -218,6 +226,9 @@ def resolve_validation_dataloaders(model: 'ModelPT'):
Args:
model: ModelPT subclass, which requires >=1 Validation Dataloaders to be setup.
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
cfg = copy.deepcopy(model._cfg)
dataloaders = []
@@ -287,6 +298,9 @@ def resolve_test_dataloaders(model: 'ModelPT'):
Args:
model: ModelPT subclass, which requires >=1 Test Dataloaders to be setup.
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
cfg = copy.deepcopy(model._cfg)
dataloaders = []
@@ -335,7 +349,7 @@ def resolve_test_dataloaders(model: 'ModelPT'):
@wrapt.decorator
def wrap_training_step(wrapped, instance: pl.LightningModule, args, kwargs):
def wrap_training_step(wrapped, instance: 'pl.LightningModule', args, kwargs):
output_dict = wrapped(*args, **kwargs)
if isinstance(output_dict, dict) and output_dict is not None and 'log' in output_dict:
@@ -345,7 +359,7 @@ def wrap_training_step(wrapped, instance: pl.LightningModule, args, kwargs):
return output_dict
def convert_model_config_to_dict_config(cfg: Union[DictConfig, 'NemoConfig']) -> DictConfig:
def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig']) -> 'DictConfig':
"""
Converts its input into a standard DictConfig.
Possible input values are:
@@ -358,6 +372,9 @@ def convert_model_config_to_dict_config(cfg: Union[DictConfig, 'NemoConfig']) ->
Returns:
The equivalent DictConfig
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)
@@ -369,8 +386,11 @@ def convert_model_config_to_dict_config(cfg: Union[DictConfig, 'NemoConfig']) ->
return config
def _convert_config(cfg: OmegaConf):
def _convert_config(cfg: 'OmegaConf'):
""" Recursive function convertint the configuration from old hydra format to the new one. """
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
# Get rid of cls -> _target_.
if 'cls' in cfg and '_target_' not in cfg:
@@ -391,7 +411,7 @@ def _convert_config(cfg: OmegaConf):
logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
def maybe_update_config_version(cfg: DictConfig):
def maybe_update_config_version(cfg: 'DictConfig'):
"""
Recursively convert Hydra 0.x configs to Hydra 1.x configs.
@@ -406,6 +426,9 @@ def maybe_update_config_version(cfg: DictConfig):
Returns:
An updated DictConfig that conforms to Hydra 1.x format.
"""
if not _HAS_HYDRA:
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
exit(1)
if cfg is not None and not isinstance(cfg, DictConfig):
try:
temp_cfg = OmegaConf.create(cfg)

View file

@@ -1,19 +1,13 @@
numpy>=1.18.2
onnx>=1.7.0
pytorch-lightning>=1.4.0
python-dateutil
torch
torchmetrics>=0.4.1rc0
wget
torch>1.7
wrapt
ruamel.yaml
scikit-learn
omegaconf>=2.1.0
hydra-core>=1.1.0
transformers>=4.8.1
sentencepiece<1.0.0
webdataset>=0.1.48,<=0.1.62
tqdm>=4.41.0
numba
grpcio
grpcio-tools
wget
frozendict
unidecode

View file

@@ -1,6 +1,5 @@
braceexpand
editdistance
frozendict
inflect
kaldi-io
librosa
@@ -9,7 +8,6 @@ packaging
ruamel.yaml
soundfile
sox
unidecode
kaldi-python-io
kaldiio
scipy

View file

@@ -0,0 +1,9 @@
pytorch-lightning>=1.3.0
torchmetrics>=0.4.1rc0
transformers>=4.0.1
webdataset>=0.1.48,<=0.1.62
opencc
pangu
jieba
omegaconf>=2.1.0
hydra-core>=1.1.0

View file

@@ -2,7 +2,6 @@ boto3
h5py
matplotlib>=3.3.2
sentencepiece
unidecode
youtokentome>=1.0.5
numpy
rapidfuzz

View file

@@ -0,0 +1,7 @@
matplotlib
pypinyin
attrdict
pystoi
pesq
g2p_en
pandas

View file

@@ -1,7 +1,2 @@
matplotlib
pypinyin
attrdict
pystoi
pesq
g2p_en
librosa
nltk

View file

@@ -29,6 +29,7 @@ markers =
docs: mark tests related to documentation (deselect with '-m "not docs"')
skipduringci: marks tests that are skipped during CI as they are addressed by Jenkins jobs but should be run to test user setups
pleasefixme: marks tests that are broken and need fixing
torch_tts: marks tests that are to be run when "torch_tts" is installed but not all nemo packages
[isort]
known_localfolder = nemo,tests

View file

@@ -79,18 +79,29 @@ install_requires = req_file("requirements.txt")
extras_require = {
# User packages
'test': req_file("requirements_test.txt"),
# Collections Packages
# NeMo Tools
'text_processing': req_file("requirements_text_processing.txt"),
# Torch Packages
'torch_tts': req_file("requirements_torch_tts.txt"),
# Lightning Collections Packages
'core': req_file("requirements_lightning.txt"),
'asr': req_file("requirements_asr.txt"),
'cv': req_file("requirements_cv.txt"),
'nlp': req_file("requirements_nlp.txt"),
'tts': req_file("requirements_tts.txt"),
'text_processing': req_file("requirements_text_processing.txt"),
}
extras_require['all'] = list(chain(extras_require.values()))
# TTS depends on ASR
extras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr']]))
# Add lightning requirements as needed
extras_require['test'] = list(chain([extras_require['tts'], extras_require['core']]))
extras_require['asr'] = list(chain([extras_require['asr'], extras_require['core']]))
extras_require['cv'] = list(chain([extras_require['cv'], extras_require['core']]))
extras_require['nlp'] = list(chain([extras_require['nlp'], extras_require['core']]))
extras_require['tts'] = list(chain([extras_require['tts'], extras_require['core']]))
# TTS has extra dependencies
extras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr'], extras_require['torch_tts']]))
tests_requirements = extras_require["test"]

View file

@@ -0,0 +1,54 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from nemo.collections.tts.torch.data import CharMelAudioDataset, PhoneMelAudioDataset
class TestCharDataset:
@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
@pytest.mark.torch_tts
def test_dataset(self, test_data_dir):
manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup')
dataset = CharMelAudioDataset(
manifest_filepath=manifest_path, sample_rate=22050, supplementary_folder=sup_path
)
dataloader = torch.utils.data.DataLoader(dataset, 2, collate_fn=dataset._collate_fn)
data, _, _, _, _, _, _ = next(iter(dataloader))
class TestPhoneDataset:
@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
@pytest.mark.torch_tts
def test_dataset(self, test_data_dir):
manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup')
dataset = PhoneMelAudioDataset(
manifest_filepath=manifest_path, sample_rate=22050, supplementary_folder=sup_path
)
dataloader = torch.utils.data.DataLoader(dataset, 2, collate_fn=dataset._collate_fn)
_, _, _, _, _, _, _ = next(iter(dataloader))

View file

@@ -15,7 +15,6 @@
import os
import pytest
import torch
from nemo.collections.tts.data.datalayers import FastSpeech2Dataset