Refactor and Minimize Dependencies (#2643)
* squash Signed-off-by: Jason <jasoli@nvidia.com> * add comments Signed-off-by: Jason <jasoli@nvidia.com> * style and cleanup Signed-off-by: Jason <jasoli@nvidia.com> * cleanup Signed-off-by: Jason <jasoli@nvidia.com> * add new test file Signed-off-by: Jason <jasoli@nvidia.com> * syntax Signed-off-by: Jason <jasoli@nvidia.com> * style Signed-off-by: Jason <jasoli@nvidia.com> * typo Signed-off-by: Jason <jasoli@nvidia.com> * update Signed-off-by: Jason <jasoli@nvidia.com> * update Signed-off-by: Jason <jasoli@nvidia.com> * update Signed-off-by: Jason <jasoli@nvidia.com> * try again Signed-off-by: Jason <jasoli@nvidia.com> * wip Signed-off-by: Jason <jasoli@nvidia.com> * style; ci should fail Signed-off-by: Jason <jasoli@nvidia.com> * final Signed-off-by: Jason <jasoli@nvidia.com>
This commit is contained in:
parent
94126c4b65
commit
4f2ea4913c
22
Jenkinsfile
vendored
22
Jenkinsfile
vendored
|
@ -40,6 +40,24 @@ pipeline {
|
|||
sh 'python setup.py style'
|
||||
}
|
||||
}
|
||||
|
||||
stage('Torch TTS unit tests') {
|
||||
when {
|
||||
anyOf {
|
||||
branch 'main'
|
||||
changeRequest target: 'main'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pip install ".[torch_tts]"'
|
||||
sh 'pip list'
|
||||
sh 'test $(pip list | grep -c lightning) -eq 0'
|
||||
sh 'test $(pip list | grep -c omegaconf) -eq 0'
|
||||
sh 'test $(pip list | grep -c hydra) -eq 0'
|
||||
sh 'pytest -m "torch_tts" --cpu tests/collections/tts/test_torch_tts.py --relax_numba_compat'
|
||||
}
|
||||
}
|
||||
|
||||
stage('Installation') {
|
||||
steps {
|
||||
sh './reinstall.sh release'
|
||||
|
@ -60,7 +78,7 @@ pipeline {
|
|||
|
||||
stage('L0: Unit Tests GPU') {
|
||||
steps {
|
||||
sh 'pytest -m "not pleasefixme" --with_downloads --relax_numba_compat'
|
||||
sh 'pytest -m "not pleasefixme and not torch_tts" --with_downloads --relax_numba_compat'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,7 +90,7 @@ pipeline {
|
|||
}
|
||||
}
|
||||
steps {
|
||||
sh 'CUDA_VISIBLE_DEVICES="" pytest -m "not pleasefixme" --cpu --with_downloads --relax_numba_compat'
|
||||
sh 'CUDA_VISIBLE_DEVICES="" pytest -m "not pleasefixme and not torch_tts" --cpu --with_downloads --relax_numba_compat'
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,15 +16,17 @@ import abc
|
|||
import itertools
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import unicodedata
|
||||
from builtins import str as unicode
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
|
||||
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.get_rank import is_global_rank_zero
|
||||
|
||||
_words_re = re.compile("([a-z\-]+'[a-z\-]+|[a-z\-]+)|([^a-z{}]+)")
|
||||
|
||||
|
@ -43,7 +45,6 @@ def _word_tokenize(text):
|
|||
return words
|
||||
|
||||
|
||||
@rank_zero_only
|
||||
def download_corpora():
|
||||
# Download NLTK datasets if this class is to be instantiated
|
||||
try:
|
||||
|
@ -310,10 +311,29 @@ class Phonemes(Base):
|
|||
self.spaces = spaces
|
||||
self.pad_with_space = pad_with_space
|
||||
|
||||
download_corpora()
|
||||
_ = sync_ddp_if_available(torch.tensor(0)) # Barrier until rank 0 downloads the corpora
|
||||
|
||||
# g2p_en tries to run download_corpora() on import but it is not rank zero guarded
|
||||
# Try to check if torch distributed is available, if not get global rank zero to download corpora and make
|
||||
# all other ranks sleep for a minute
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
group = torch.distributed.group.WORLD
|
||||
if is_global_rank_zero():
|
||||
download_corpora()
|
||||
torch.distributed.barrier(group=group)
|
||||
elif is_global_rank_zero():
|
||||
logging.error(
|
||||
f"Torch distributed needs to be initialized before you initialized {self}. This class is prone to "
|
||||
"data access race conditions. Now downloading corpora from global rank 0. If other ranks pass this "
|
||||
"before rank 0, errors might result."
|
||||
)
|
||||
download_corpora()
|
||||
else:
|
||||
logging.error(
|
||||
f"Torch distributed needs to be initialized before you initialized {self}. This class is prone to "
|
||||
"data access race conditions. This process is not rank 0, and now going to sleep for 1 min. If this "
|
||||
"rank wakes from sleep prior to rank 0 finishing downloading, errors might result."
|
||||
)
|
||||
time.sleep(60)
|
||||
|
||||
import g2p_en # noqa pylint: disable=import-outside-toplevel
|
||||
|
||||
_g2p = g2p_en.G2p()
|
||||
|
|
|
@ -12,11 +12,27 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.asr.models.asr_model import ASRModel
|
||||
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
|
||||
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
|
||||
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
|
||||
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
|
||||
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel, ExtractSpeakerEmbeddingsModel
|
||||
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
|
||||
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
try:
|
||||
from nemo.collections.asr.models.asr_model import ASRModel
|
||||
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
|
||||
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
|
||||
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
|
||||
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
|
||||
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel, ExtractSpeakerEmbeddingsModel
|
||||
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
|
||||
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
# fmt: off
|
||||
class ASRModel(CheckInstall): pass
|
||||
class EncDecClassificationModel(CheckInstall): pass
|
||||
class ClusteringDiarizer(CheckInstall): pass
|
||||
class EncDecCTCModelBPE(CheckInstall): pass
|
||||
class EncDecCTCModel(CheckInstall): pass
|
||||
class EncDecSpeakerLabelModel(CheckInstall): pass
|
||||
class ExtractSpeakerEmbeddingsModel(CheckInstall): pass
|
||||
class EncDecRNNTBPEModel(CheckInstall): pass
|
||||
class EncDecRNNTModel(CheckInstall): pass
|
||||
# fmt: on
|
||||
|
|
|
@ -20,13 +20,27 @@ from nemo.collections.asr.modules.audio_preprocessing import (
|
|||
)
|
||||
from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM
|
||||
from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
|
||||
from nemo.collections.asr.modules.conv_asr import (
|
||||
ConvASRDecoder,
|
||||
ConvASRDecoderClassification,
|
||||
ConvASREncoder,
|
||||
ECAPAEncoder,
|
||||
ParallelConvASREncoder,
|
||||
SpeakerDecoder,
|
||||
)
|
||||
from nemo.collections.asr.modules.lstm_decoder import LSTMDecoder
|
||||
from nemo.collections.asr.modules.rnnt import RNNTDecoder, RNNTJoint
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
try:
|
||||
from nemo.collections.asr.modules.conv_asr import (
|
||||
ConvASRDecoder,
|
||||
ConvASRDecoderClassification,
|
||||
ConvASREncoder,
|
||||
ECAPAEncoder,
|
||||
ParallelConvASREncoder,
|
||||
SpeakerDecoder,
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
# fmt: off
|
||||
class ConvASRDecoder(CheckInstall): pass
|
||||
class ConvASRDecoderClassification(CheckInstall): pass
|
||||
class ConvASREncoder(CheckInstall): pass
|
||||
class ECAPAEncoder(CheckInstall): pass
|
||||
class ParallelConvASREncoder(CheckInstall): pass
|
||||
class SpeakerDecoder(CheckInstall): pass
|
||||
# fmt: on
|
||||
|
|
|
@ -41,13 +41,23 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from librosa.util import tiny
|
||||
from torch.autograd import Variable
|
||||
from torch_stft import STFT
|
||||
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor
|
||||
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
||||
from nemo.collections.common.parts.patch_utils import stft_patch
|
||||
from nemo.utils import logging
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
try:
|
||||
from torch_stft import STFT
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
# fmt: off
|
||||
class STFT(CheckInstall): pass
|
||||
# fmt: on
|
||||
|
||||
|
||||
CONSTANT = 1e-5
|
||||
|
||||
|
||||
|
|
|
@ -43,8 +43,6 @@ from typing import List, Optional, Union
|
|||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import webdataset as wd
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from scipy import signal
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
@ -52,6 +50,17 @@ from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
|||
from nemo.collections.common.parts.preprocessing import collections, parsers
|
||||
from nemo.utils import logging
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
HAVE_OMEGACONG_WEBDATASET = True
|
||||
try:
|
||||
import webdataset as wd
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import LightningNotInstalledException
|
||||
|
||||
HAVE_OMEGACONG_WEBDATASET = False
|
||||
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.parts.utils import numba_utils
|
||||
|
||||
|
@ -792,10 +801,13 @@ def process_augmentations(augmenter) -> Optional[AudioAugmentor]:
|
|||
if isinstance(augmenter, AudioAugmentor):
|
||||
return augmenter
|
||||
|
||||
if not type(augmenter) in {dict, DictConfig}:
|
||||
augmenter_types = {dict}
|
||||
if HAVE_OMEGACONG_WEBDATASET:
|
||||
augmenter_types = {dict, DictConfig}
|
||||
if not type(augmenter) in augmenter_types:
|
||||
raise ValueError("Cannot parse augmenter. Must be a dict or an AudioAugmentor object ")
|
||||
|
||||
if isinstance(augmenter, DictConfig):
|
||||
if HAVE_OMEGACONG_WEBDATASET and isinstance(augmenter, DictConfig):
|
||||
augmenter = OmegaConf.to_container(augmenter, resolve=True)
|
||||
|
||||
augmenter = copy.deepcopy(augmenter)
|
||||
|
@ -864,6 +876,8 @@ class AugmentationDataset(IterableDataset):
|
|||
if bkey in tar_filepaths:
|
||||
tar_filepaths = tar_filepaths.replace(bkey, "}")
|
||||
|
||||
if not HAVE_OMEGACONG_WEBDATASET:
|
||||
raise LightningNotInstalledException(self)
|
||||
self.audio_dataset = wd.WebDataset(urls=tar_filepaths, nodesplitter=None)
|
||||
|
||||
if shuffle_n > 0:
|
||||
|
|
|
@ -39,13 +39,20 @@ import random
|
|||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from kaldiio.matio import read_kaldi
|
||||
from kaldiio.utils import open_like_kaldi
|
||||
from pydub import AudioSegment as Audio
|
||||
from pydub.exceptions import CouldntDecodeError
|
||||
|
||||
from nemo.utils import logging
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
HAVE_KALDI_PYDUB = True
|
||||
try:
|
||||
from kaldiio.matio import read_kaldi
|
||||
from kaldiio.utils import open_like_kaldi
|
||||
from pydub import AudioSegment as Audio
|
||||
from pydub.exceptions import CouldntDecodeError
|
||||
except ModuleNotFoundError:
|
||||
HAVE_KALDI_PYDUB = False
|
||||
|
||||
|
||||
available_formats = sf.available_formats()
|
||||
sf_supported_formats = ["." + i.lower() for i in available_formats.keys()]
|
||||
|
||||
|
@ -149,7 +156,7 @@ class AudioSegment(object):
|
|||
f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`. "
|
||||
f"NeMo will fallback to loading via pydub."
|
||||
)
|
||||
elif isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
|
||||
elif HAVE_KALDI_PYDUB and isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
|
||||
f = open_like_kaldi(audio_file, "rb")
|
||||
sample_rate, samples = read_kaldi(f)
|
||||
if offset > 0:
|
||||
|
@ -160,7 +167,7 @@ class AudioSegment(object):
|
|||
abs_max_value = np.abs(samples).max()
|
||||
samples = np.array(samples, dtype=np.float) / abs_max_value
|
||||
|
||||
if samples is None:
|
||||
if HAVE_KALDI_PYDUB and samples is None:
|
||||
try:
|
||||
samples = Audio.from_file(audio_file)
|
||||
sample_rate = samples.frame_rate
|
||||
|
@ -172,8 +179,12 @@ class AudioSegment(object):
|
|||
seconds = duration * 1000
|
||||
samples = samples[: int(seconds)]
|
||||
samples = np.array(samples.get_array_of_samples())
|
||||
except CouldntDecodeError as e:
|
||||
logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{e}`.")
|
||||
except CouldntDecodeError as err:
|
||||
logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{err}`.")
|
||||
|
||||
if samples is None:
|
||||
libs = "soundfile, kaldiio, and pydub" if HAVE_KALDI_PYDUB else "soundfile"
|
||||
raise Exception(f"Your audio file {audio_file} could not be decoded. We tried using {libs}.")
|
||||
|
||||
return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
|
||||
|
||||
|
|
|
@ -12,4 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
try:
|
||||
from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
# fmt: off
|
||||
class LogEpochTimeCallback(CheckInstall): pass
|
||||
# fmt: on
|
||||
|
|
|
@ -11,6 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''
|
||||
A collection of simple character based parsers. These parser handle cleaning and tokenization by default.
|
||||
We currently support English.
|
||||
'''
|
||||
|
||||
import string
|
||||
from typing import List, Optional
|
||||
|
||||
|
@ -36,6 +42,7 @@ class CharParser:
|
|||
blank_id: int = -1,
|
||||
do_normalize: bool = True,
|
||||
do_lowercase: bool = True,
|
||||
do_tokenize: bool = True,
|
||||
):
|
||||
"""Creates simple mapping char parser.
|
||||
|
||||
|
@ -56,6 +63,7 @@ class CharParser:
|
|||
self._blank_id = blank_id
|
||||
self._do_normalize = do_normalize
|
||||
self._do_lowercase = do_lowercase
|
||||
self._do_tokenize = do_tokenize
|
||||
|
||||
self._labels_map = {label: index for index, label in enumerate(labels)}
|
||||
self._special_labels = set([label for label in labels if len(label) > 1])
|
||||
|
@ -66,8 +74,10 @@ class CharParser:
|
|||
if text is None:
|
||||
return None
|
||||
|
||||
text_tokens = self._tokenize(text)
|
||||
if not self._do_tokenize:
|
||||
return text
|
||||
|
||||
text_tokens = self._tokenize(text)
|
||||
return text_tokens
|
||||
|
||||
def _normalize(self, text: str) -> Optional[str]:
|
||||
|
@ -97,6 +107,23 @@ class CharParser:
|
|||
|
||||
return tokens
|
||||
|
||||
def decode(self, str_input):
    """Convert a sequence of label ids back into a string.

    The reverse mapping is built from ``self._labels_map`` and extended with
    three extra ids placed directly after the label table: ``<BOS>``,
    ``<EOS>`` and ``<P>`` (in that order).

    Args:
        str_input: Iterable of ids. Elements may be plain ints or scalar
            objects exposing ``.item()`` (e.g. tensor / array scalars).

    Returns:
        The concatenation of the labels for every id found in the mapping;
        ids that are not in the mapping are skipped.
    """
    r_map = {index: label for label, index in self._labels_map.items()}
    num_labels = len(self._labels_map)
    r_map[num_labels] = "<BOS>"
    r_map[num_labels + 1] = "<EOS>"
    r_map[num_labels + 2] = "<P>"

    out = []
    for i in str_input:
        # Normalise to a plain Python value BEFORE the lookup: scalar wrappers
        # that hash by identity would never match the int keys of r_map, and
        # plain ints have no .item().
        index = i.item() if hasattr(i, "item") else i
        # Skip OOV
        if index not in r_map:
            continue
        out.append(r_map[index])

    return "".join(out)
|
||||
|
||||
|
||||
class ENCharParser(CharParser):
|
||||
"""Incorporates english-specific parsing logic."""
|
||||
|
|
|
@ -14,7 +14,17 @@
|
|||
|
||||
from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer
|
||||
from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer
|
||||
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
|
||||
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
|
||||
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
|
||||
from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
try:
|
||||
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
|
||||
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
# fmt: off
|
||||
class AutoTokenizer(CheckInstall): pass
|
||||
class SentencePieceTokenizer(CheckInstall): pass
|
||||
# fmt: on
|
||||
|
|
|
@ -23,10 +23,22 @@ from numba import jit, prange
|
|||
from numpy import ndarray
|
||||
from pesq import pesq
|
||||
from pystoi import stoi
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from nemo.utils import logging
|
||||
|
||||
try:
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
except ModuleNotFoundError:
|
||||
from functools import wraps
|
||||
|
||||
def rank_zero_only(fn):
    """Fallback for ``pytorch_lightning.utilities.rank_zero_only``.

    Used only when pytorch_lightning could not be imported. The returned
    wrapper never runs ``fn``: calling it logs an error explaining that
    lightning is required and then exits the process with status 1.

    Args:
        fn: The function being decorated.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``).
    """

    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        logging.error(
            f"Function {fn} requires lighting to be installed, but it was not found. Please install lightning first"
        )
        exit(1)

    # The wrapper must be returned; otherwise the decorated name is rebound to
    # None and calling it raises TypeError instead of reporting the real problem.
    return wrapped_fn
|
||||
|
||||
|
||||
class OperationMode(Enum):
|
||||
"""Training or Inference (Evaluation) mode"""
|
||||
|
|
|
@ -12,22 +12,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.tts.models.aligner import AlignerModel
|
||||
from nemo.collections.tts.models.degli import DegliModel
|
||||
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
|
||||
from nemo.collections.tts.models.fastpitch import FastPitchModel
|
||||
from nemo.collections.tts.models.fastpitch_hifigan_e2e import FastPitchHifiGanE2EModel
|
||||
from nemo.collections.tts.models.fastspeech2 import FastSpeech2Model
|
||||
from nemo.collections.tts.models.fastspeech2_hifigan_e2e import FastSpeech2HifiGanE2EModel
|
||||
from nemo.collections.tts.models.glow_tts import GlowTTSModel
|
||||
from nemo.collections.tts.models.hifigan import HifiGanModel
|
||||
from nemo.collections.tts.models.melgan import MelGanModel
|
||||
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
|
||||
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
|
||||
from nemo.collections.tts.models.talknet import TalkNetDursModel, TalkNetPitchModel, TalkNetSpectModel
|
||||
from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel
|
||||
from nemo.collections.tts.models.uniglow import UniGlowModel
|
||||
from nemo.collections.tts.models.waveglow import WaveGlowModel
|
||||
try:
|
||||
from nemo.collections.tts.models.aligner import AlignerModel
|
||||
from nemo.collections.tts.models.degli import DegliModel
|
||||
from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel
|
||||
from nemo.collections.tts.models.fastpitch import FastPitchModel
|
||||
from nemo.collections.tts.models.fastpitch_hifigan_e2e import FastPitchHifiGanE2EModel
|
||||
from nemo.collections.tts.models.fastspeech2 import FastSpeech2Model
|
||||
from nemo.collections.tts.models.fastspeech2_hifigan_e2e import FastSpeech2HifiGanE2EModel
|
||||
from nemo.collections.tts.models.glow_tts import GlowTTSModel
|
||||
from nemo.collections.tts.models.hifigan import HifiGanModel
|
||||
from nemo.collections.tts.models.melgan import MelGanModel
|
||||
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
|
||||
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
|
||||
from nemo.collections.tts.models.talknet import TalkNetDursModel, TalkNetPitchModel, TalkNetSpectModel
|
||||
from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel
|
||||
from nemo.collections.tts.models.uniglow import UniGlowModel
|
||||
from nemo.collections.tts.models.waveglow import WaveGlowModel
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
__all__ = [
|
||||
"GlowTTSModel",
|
||||
|
|
|
@ -44,8 +44,8 @@ import torch.nn.functional as F
|
|||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
from nemo.core.classes import NeuralModule
|
||||
from nemo.core.classes.common import typecheck
|
||||
from nemo.core.classes.module import NeuralModule
|
||||
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
|
||||
from nemo.core.neural_types.neural_type import NeuralType
|
||||
|
||||
|
|
13
nemo/collections/tts/torch/__init__.py
Normal file
13
nemo/collections/tts/torch/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
521
nemo/collections/tts/torch/data.py
Normal file
521
nemo/collections/tts/torch/data.py
Normal file
|
@ -0,0 +1,521 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
|
||||
from nemo.collections.asr.data.vocabs import Base, Phonemes
|
||||
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
||||
from nemo.collections.common.parts.patch_utils import stft_patch
|
||||
from nemo.collections.common.parts.preprocessing.parsers import make_parser
|
||||
from nemo.collections.tts.torch.helpers import beta_binomial_prior_distribution
|
||||
from nemo.core.classes import Dataset
|
||||
from nemo.core.neural_types.elements import *
|
||||
from nemo.core.neural_types.neural_type import NeuralType
|
||||
from nemo.utils import logging
|
||||
|
||||
CONSTANT = 1e-5
|
||||
|
||||
|
||||
class CharMelAudioDataset(Dataset):
|
||||
@property
|
||||
def output_types(self) -> Optional[Dict[str, NeuralType]]:
|
||||
return {
|
||||
'transcripts': NeuralType(('B', 'T'), TokenIndex()),
|
||||
'transcript_length': NeuralType(('B'), LengthsType()),
|
||||
'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
|
||||
'mel_length': NeuralType(('B'), LengthsType()),
|
||||
'audio': NeuralType(('B', 'T'), AudioSignal()),
|
||||
'audio_length': NeuralType(('B'), LengthsType()),
|
||||
'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
|
||||
'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
|
||||
'energies': NeuralType(('B', 'T'), RegressionValuesType()),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath: str,
|
||||
sample_rate: int,
|
||||
supplementary_folder: Path,
|
||||
max_duration: Optional[float] = None,
|
||||
min_duration: Optional[float] = None,
|
||||
ignore_file: Optional[str] = None,
|
||||
trim: bool = False,
|
||||
n_fft=1024,
|
||||
win_length=None,
|
||||
hop_length=None,
|
||||
window="hann",
|
||||
n_mels=64,
|
||||
lowfreq=0,
|
||||
highfreq=None,
|
||||
pitch_fmin=80,
|
||||
pitch_fmax=640,
|
||||
pitch_avg=0,
|
||||
pitch_std=1,
|
||||
tokenize_text=True,
|
||||
):
|
||||
"""Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
|
||||
Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
|
||||
they did not exist before.
|
||||
|
||||
Args:
|
||||
manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
|
||||
dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
|
||||
json. Each line should contain the following:
|
||||
"audio_filepath": <PATH_TO_WAV>
|
||||
"mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
|
||||
"duration": <Duration of audio clip in seconds> (Optional)
|
||||
"text": <THE_TRANSCRIPT> (Optional)
|
||||
sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
|
||||
supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
|
||||
specified in the manifest .json file. It will also contain priors, pitches, and energies
|
||||
max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
|
||||
pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
|
||||
audio to compute duration. Defaults to None which does not prune.
|
||||
min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
|
||||
pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
|
||||
audio to compute duration. Defaults to None which does not prune.
|
||||
ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
|
||||
files) that will be pruned prior to training. Defaults to None which does not prune.
|
||||
trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
|
||||
n_fft (Optional[int]): The number of fft samples. Defaults to 1024
|
||||
win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
|
||||
hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
|
||||
window (Optional[str]): One of 'hann', 'hamming', 'blackman','bartlett', 'none'. Which corresponds to the
|
||||
equivalent torch window function.
|
||||
n_mels (Optional[int]): The number of mel filters. Defaults to 64.
|
||||
lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
|
||||
highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
|
||||
pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to 80.
|
||||
pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to 640.
|
||||
pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
|
||||
pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
|
||||
tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.pitch_fmin = pitch_fmin
|
||||
self.pitch_fmax = pitch_fmax
|
||||
self.pitch_avg = pitch_avg
|
||||
self.pitch_std = pitch_std
|
||||
self.win_length = win_length or n_fft
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_len = hop_length or n_fft // 4
|
||||
|
||||
self.parser = make_parser(name="en", do_tokenize=tokenize_text)
|
||||
self.pad_id = self.parser._blank_id
|
||||
Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
|
||||
self.supplementary_folder = supplementary_folder
|
||||
|
||||
audio_files = []
|
||||
total_duration = 0
|
||||
# Load data from manifests
|
||||
# Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
|
||||
# extract pitch from the audio
|
||||
# Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
|
||||
# compute the mel on the fly and save it to the supplementary folder
|
||||
# Note: text is not required. Any models that rely on text (spectrogram generators, end-to-end models) will
|
||||
# fail if not set. However vocoders (mel -> audio) will be able to work without text
|
||||
if isinstance(manifest_filepath, str):
|
||||
manifest_filepath = [manifest_filepath]
|
||||
for manifest_file in manifest_filepath:
|
||||
with open(Path(manifest_file).expanduser(), 'r') as f:
|
||||
logging.info(f"Loading dataset from {manifest_file}.")
|
||||
for line in f:
|
||||
item = json.loads(line)
|
||||
# Grab audio, text, mel if they exist
|
||||
file_info = {}
|
||||
file_info["audio_filepath"] = item["audio_filepath"]
|
||||
file_info["mel_filepath"] = item["mel_filepath"] if "mel_filepath" in item else None
|
||||
file_info["duration"] = item["duration"] if "duration" in item else None
|
||||
# Parse text
|
||||
file_info["text_tokens"] = None
|
||||
if "text" in item:
|
||||
text = item["text"]
|
||||
text_tokens = self.parser(text)
|
||||
file_info["text_tokens"] = text_tokens
|
||||
audio_files.append(file_info)
|
||||
if file_info["duration"] is None:
|
||||
logging.info(
|
||||
"Not all audio files have duration information. Duration logging will be disabled."
|
||||
)
|
||||
total_duration = None
|
||||
if total_duration is not None:
|
||||
total_duration += item["duration"]
|
||||
|
||||
logging.info(f"Loaded dataset with {len(audio_files)} files.")
|
||||
if total_duration is not None:
|
||||
logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")
|
||||
|
||||
self.data = []
|
||||
|
||||
if ignore_file:
|
||||
logging.info(f"using {ignore_file} to prune dataset.")
|
||||
with open(Path(ignore_file).expanduser(), "rb") as f:
|
||||
wavs_to_ignore = set(pickle.load(f))
|
||||
|
||||
pruned_duration = 0 if total_duration is not None else None
|
||||
pruned_items = 0
|
||||
for item in audio_files:
|
||||
audio_path = item['audio_filepath']
|
||||
audio_id = Path(audio_path).stem
|
||||
|
||||
# Prune data according to min/max_duration & the ignore file
|
||||
if total_duration is not None:
|
||||
if (min_duration and item["duration"] < min_duration) or (
|
||||
max_duration and item["duration"] > max_duration
|
||||
):
|
||||
pruned_duration += item["duration"]
|
||||
pruned_items += 1
|
||||
continue
|
||||
|
||||
if ignore_file and (audio_id in wavs_to_ignore):
|
||||
pruned_items += 1
|
||||
pruned_duration += item["duration"]
|
||||
wavs_to_ignore.remove(audio_id)
|
||||
continue
|
||||
|
||||
self.data.append(item)
|
||||
|
||||
logging.info(f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files")
|
||||
if pruned_duration is not None:
|
||||
logging.info(
|
||||
f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
|
||||
f"{(total_duration-pruned_duration)/3600:.2f} hours."
|
||||
)
|
||||
|
||||
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
|
||||
self.trim = trim
|
||||
|
||||
filterbanks = torch.tensor(
|
||||
librosa.filters.mel(sample_rate, n_fft, n_mels=n_mels, fmin=lowfreq, fmax=highfreq), dtype=torch.float
|
||||
).unsqueeze(0)
|
||||
self.fb = filterbanks
|
||||
|
||||
torch_windows = {
|
||||
'hann': torch.hann_window,
|
||||
'hamming': torch.hamming_window,
|
||||
'blackman': torch.blackman_window,
|
||||
'bartlett': torch.bartlett_window,
|
||||
'none': None,
|
||||
}
|
||||
window_fn = torch_windows.get(window, None)
|
||||
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
|
||||
|
||||
self.stft = lambda x: stft_patch(
|
||||
input=x,
|
||||
n_fft=n_fft,
|
||||
hop_length=self.hop_len,
|
||||
win_length=self.win_length,
|
||||
window=window_tensor.to(torch.float),
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
    """Return one sample together with its supplementary features.

    The log mel, attention prior, pitch and energy are each loaded from a cached ``.pt`` file
    when one exists; otherwise they are computed here and saved so later epochs can reuse them.

    Args:
        index: Position of the sample inside ``self.data``.

    Returns:
        The tuple ``(text, text_length, log_mel, log_mel_length, audio, audio_length,
        duration_prior, pitch, energy)``. When ``sample["text_tokens"]`` is a string, ``text`` is
        that raw string and both ``text_length`` and ``duration_prior`` are ``None``.
    """
    # Magnitude spectrogram, computed at most once and shared by the mel and energy branches
    spec = None
    sample = self.data[index]

    features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
    audio, audio_length = features, torch.tensor(features.shape[0]).long()
    if isinstance(sample["text_tokens"], str):
        # If tokenize_text is False for Phone dataset
        text = sample["text_tokens"]
        text_length = None
    else:
        text = torch.tensor(sample["text_tokens"]).long()
        text_length = torch.tensor(len(sample["text_tokens"])).long()
    audio_stem = Path(sample["audio_filepath"]).stem

    def magnitude_spec():
        """Compute the STFT magnitude of ``audio`` exactly as the mel branch needs it."""
        # disable autocast to get full range of stft values
        with torch.cuda.amp.autocast(enabled=False):
            raw = self.stft(audio)

        # guard is needed for sqrt if grads are passed through
        guard = CONSTANT  # TODO: Enable 0 if not self.use_grads else CONSTANT
        if raw.dtype in [torch.cfloat, torch.cdouble]:
            raw = torch.view_as_real(raw)
        return torch.sqrt(raw.pow(2).sum(-1) + guard)

    # Load mel if it exists
    mel_path = sample["mel_filepath"]
    if mel_path and Path(mel_path).exists():
        log_mel = torch.load(mel_path)
    else:
        mel_path = Path(self.supplementary_folder) / f"mel_{audio_stem}.pt"
        if mel_path.exists():
            log_mel = torch.load(mel_path)
        else:
            spec = magnitude_spec()

            mel = torch.matmul(self.fb.to(spec.dtype), spec)

            # Clamp before the log so silent bins do not produce -inf
            log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
            torch.save(log_mel, mel_path)

    log_mel = log_mel.squeeze(0)
    log_mel_length = torch.tensor(log_mel.shape[1]).long()

    duration_prior = None
    if text_length is not None:
        ### Make duration attention prior if not exist in the supplementary folder
        prior_path = Path(self.supplementary_folder) / f"pr_tl{text_length}_al_{log_mel_length}.pt"
        if prior_path.exists():
            duration_prior = torch.load(prior_path)
        else:
            duration_prior = beta_binomial_prior_distribution(text_length, log_mel_length)
            duration_prior = torch.from_numpy(duration_prior)
            torch.save(duration_prior, prior_path)

    # Load pitch file (F0s)
    pitch_path = (
        Path(self.supplementary_folder)
        / f"{audio_stem}_pitch_pyin_fmin{self.pitch_fmin}_fmax{self.pitch_fmax}_fl{self.win_length}_hs{self.hop_len}.pt"
    )
    if pitch_path.exists():
        pitch = torch.load(pitch_path)
    else:
        pitch, _, _ = librosa.pyin(
            audio.numpy(),
            fmin=self.pitch_fmin,
            fmax=self.pitch_fmax,
            frame_length=self.win_length,
            sr=self.sample_rate,
            fill_na=0.0,
        )
        pitch = torch.from_numpy(pitch)
        # The raw (unstandardized) pitch is what gets cached
        torch.save(pitch, pitch_path)
    # Standardize pitch
    pitch -= self.pitch_avg
    pitch[pitch == -self.pitch_avg] = 0.0  # Zero out values that were previously zero
    pitch /= self.pitch_std

    # Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
    energy_path = Path(self.supplementary_folder) / f"{audio_stem}_energy_wl{self.win_length}_hs{self.hop_len}.pt"
    if energy_path.exists():
        energy = torch.load(energy_path)
    else:
        if spec is None:
            # The mel came from a cache, so no spectrogram exists yet. Use the same magnitude
            # computation as the mel branch so the energy does not depend on what was cached.
            spec = magnitude_spec()
        energy = torch.linalg.norm(spec.squeeze(0), axis=0)
        # Save to new file
        torch.save(energy, energy_path)

    return text, text_length, log_mel, log_mel_length, audio, audio_length, duration_prior, pitch, energy
|
||||
|
||||
def __len__(self):
    """Return how many samples are stored in ``self.data``."""
    sample_count = len(self.data)
    return sample_count
|
||||
|
||||
def _collate_fn(self, batch):
    """Right-pad every field of a list of samples and stack them into batch tensors.

    Args:
        batch: List of 9-tuples as produced by ``__getitem__``:
            ``(token, token_len, log_mel, log_mel_len, audio, audio_len, duration_prior, pitch, energy)``.

    Returns:
        ``(tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies)``.
        Audio is padded and stacked as well, but it is only logged, not returned.
    """
    # Log mels are padded with the smallest positive normal value of their dtype
    log_mel_pad = torch.finfo(batch[0][2].dtype).tiny

    (
        _,
        tokens_lengths,
        _,
        log_mel_lengths,
        _,
        audio_lengths,
        duration_priors_list,
        pitch_seqs,
        energy_seqs,
    ) = zip(*batch)

    max_tokens_len = max(tokens_lengths).item()
    max_log_mel_len = max(log_mel_lengths)
    max_audio_len = max(audio_lengths).item()
    max_pitches_len = max([len(seq) for seq in pitch_seqs])
    max_energies_len = max([len(seq) for seq in energy_seqs])
    lengths_agree = max_pitches_len == max_energies_len and max_pitches_len == max_log_mel_len
    if not lengths_agree:
        logging.warning(
            f"max_pitches_len: {max_pitches_len} != max_energies_len: {max_energies_len} != "
            f"max_mel_len:{max_log_mel_len}. Your training run will error out!"
        )

    def pad_to(tensor, current_len, target_len, **fill):
        # Right-pad the last dimension; tensors already at the target length pass through as-is
        if current_len < target_len:
            tensor = torch.nn.functional.pad(tensor, (0, target_len - current_len), **fill)
        return tensor

    # Priors are written into a zero tensor sized for the largest prior in each dimension
    duration_priors = torch.zeros(
        len(duration_priors_list),
        max([prior_i.shape[0] for prior_i in duration_priors_list]),
        max([prior_i.shape[1] for prior_i in duration_priors_list]),
    )
    audios, tokens, log_mels, pitches, energies = [], [], [], [], []
    for row, sample_tuple in enumerate(batch):
        token, token_len, log_mel, log_mel_len, audio, audio_len, duration_prior, pitch, energy = sample_tuple
        tokens.append(pad_to(token, token_len.item(), max_tokens_len, value=self.pad_id))
        log_mels.append(pad_to(log_mel, log_mel_len, max_log_mel_len, value=log_mel_pad))
        audios.append(pad_to(audio, audio_len.item(), max_audio_len))
        duration_priors[row, : duration_prior.shape[0], : duration_prior.shape[1]] = duration_prior
        pitches.append(pad_to(pitch, len(pitch), max_pitches_len))
        energies.append(pad_to(energy, len(energy), max_energies_len))

    audios = torch.stack(audios)
    audio_lengths = torch.stack(audio_lengths)
    tokens = torch.stack(tokens)
    tokens_lengths = torch.stack(tokens_lengths)
    log_mels = torch.stack(log_mels)
    log_mel_lengths = torch.stack(log_mel_lengths)
    pitches = torch.stack(pitches)
    energies = torch.stack(energies)

    for label, stacked in (
        ("audios", audios),
        ("audio_lengths", audio_lengths),
        ("tokens", tokens),
        ("tokens_lengths", tokens_lengths),
        ("log_mels", log_mels),
        ("log_mel_lengths", log_mel_lengths),
        ("duration_priors", duration_priors),
        ("pitches", pitches),
        ("energies", energies),
    ):
        logging.debug(f"{label}: {stacked.shape}")

    return (tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies)
|
||||
|
||||
def decode(self, tokens):
    """Decode one token sequence (not a batch) with ``self.parser``."""
    # After squeezing, the input must be a scalar or a flat sequence
    squeezed_rank = len(tokens.squeeze().shape)
    assert squeezed_rank in [0, 1]
    return self.parser.decode(tokens)
|
||||
|
||||
|
||||
class PhoneMelAudioDataset(CharMelAudioDataset):
    """Variant of ``CharMelAudioDataset`` that encodes transcripts with a ``Phonemes`` vocabulary."""

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            'transcripts': NeuralType(('B', 'T'), TokenIndex()),
            'transcript_length': NeuralType(('B'), LengthsType()),
            'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
            'mel_length': NeuralType(('B'), LengthsType()),
            'audio': NeuralType(('B', 'T'), AudioSignal()),
            'audio_length': NeuralType(('B'), LengthsType()),
            'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
            'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
            'energies': NeuralType(('B', 'T'), RegressionValuesType()),
        }

    def __init__(
        self,
        punct=True,
        stresses=False,
        spaces=True,
        chars=False,
        space=' ',
        silence=None,
        apostrophe=True,
        oov=Base.OOV,
        sep='|',
        add_blank_at="last_but_one",
        pad_with_space=False,
        improved_version_g2p=False,
        phoneme_dict_path=None,
        **kwargs,
    ):
        """Dataset which extends CharMelAudioDataset to load phones in place of characters. It returns audio, log mel
        specs, phone tokens, duration / attention priors, pitches, and energies. Log mels, priors, pitches, and
        energies will be computed on the fly and saved in the supplementary_folder if they did not exist before. These
        supplementary files can be shared with CharMelAudioDataset.

        Args:
            punct (bool): Whether to keep punctuation in the input. Defaults to True
            stresses (bool): Whether to add phone stresses in the input. Defaults to False
            spaces (bool): Whether to encode space characters. Defaults to True
            chars (bool): Whether to add characters to the labels map. NOTE: The current parser class does not
                actually parse transcripts to characters. Defaults to False
            space (str): The space character. Defaults to ' '
            silence: Passed through to ``Phonemes`` to control silence tokens. Defaults to None
            apostrophe (bool): Whether to keep apostrophes. Defaults to True
            oov (str): How out of vocabulary tokens are decoded. Defaults to Base.OOV == "<oov>"
            sep (str): How to separate phones when tokens are decoded. Defaults to "|"
            add_blank_at (str): Where to add the blank symbol that is used in CTC. Can be None which does not add a
                blank token in the vocab, "last" which makes self.vocab.labels[-1] the blank token, or
                "last_but_one" which makes self.vocab.labels[-2] the blank token
            pad_with_space (bool): Whether to pad input with space tokens at start and end. Defaults to False
            improved_version_g2p (bool): Defaults to False
            phoneme_dict_path (path): Location of cmudict. Defaults to None which means the code will download it
                automatically
            **kwargs: Forwarded to ``CharMelAudioDataset.__init__``. A ``tokenize_text`` entry is removed and
                replaced by ``False``.
        """
        if "tokenize_text" in kwargs:
            tokenize_text = kwargs.pop("tokenize_text")
            # Warn only when the caller asked for something other than the forced value of False
            if tokenize_text:
                logging.warning(
                    f"{self} requires tokenize_text to be False. Setting it to False and ignoring provided value of "
                    f"{tokenize_text}"
                )
        super().__init__(tokenize_text=False, **kwargs)

        self.vocab = Phonemes(
            punct=punct,
            stresses=stresses,
            spaces=spaces,
            chars=chars,
            add_blank_at=add_blank_at,
            pad_with_space=pad_with_space,
            improved_version_g2p=improved_version_g2p,
            phoneme_dict_path=phoneme_dict_path,
            space=space,
            silence=silence,
            apostrophe=apostrophe,
            oov=oov,
            sep=sep,
        )
        self.pad_id = self.vocab.pad

    def __getitem__(self, index):
        """Return the parent sample with its text replaced by phone tokens and a matching prior.

        The text length and duration prior returned by the parent are discarded; both are rebuilt
        here from the phone token count.
        """
        (text, _, log_mel, log_mel_length, audio, audio_length, _, pitch, energy) = super().__getitem__(index)

        phones_tokenized = torch.tensor(self.vocab.encode(text)).long()
        phones_length = torch.tensor(len(phones_tokenized)).long()

        ### Make duration attention prior if not exist in the supplementary folder
        prior_path = Path(self.supplementary_folder) / f"pr_tl{phones_length}_al_{log_mel_length}.pt"
        if prior_path.exists():
            duration_prior = torch.load(prior_path)
        else:
            duration_prior = beta_binomial_prior_distribution(phones_length, log_mel_length)
            duration_prior = torch.from_numpy(duration_prior)
            torch.save(duration_prior, prior_path)

        return (
            phones_tokenized,
            phones_length,
            log_mel,
            log_mel_length,
            audio,
            audio_length,
            duration_prior,
            pitch,
            energy,
        )

    def __len__(self):
        """Return how many samples are stored in ``self.data``."""
        return len(self.data)

    def decode(self, tokens):
        """
        Accepts a single list of tokens, not a batch
        """
        assert len(tokens.squeeze().shape) in [0, 1]
        return self.vocab.decode(tokens)
|
26
nemo/collections/tts/torch/helpers.py
Normal file
26
nemo/collections/tts/torch/helpers.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import betabinom
|
||||
|
||||
|
||||
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    """Build a beta-binomial attention prior with one row per mel frame.

    For mel frame ``i`` (1-based) the row is the pmf of ``betabinom(phoneme_count, a, b)`` with
    ``a = scaling_factor * i`` and ``b = scaling_factor * (mel_count + 1 - i)``, evaluated at
    ``0 .. phoneme_count - 1``.

    Args:
        phoneme_count: Number of text positions; sets both ``n`` of the distribution and the row width.
        mel_count: Number of mel frames, i.e. the number of rows.
        scaling_factor: Multiplier applied to both shape parameters.

    Returns:
        ``np.ndarray`` of shape ``(mel_count, phoneme_count)``; an empty 1-D array when ``mel_count`` is 0.
    """
    support = np.arange(0, phoneme_count)
    rows = [
        betabinom(phoneme_count, scaling_factor * frame, scaling_factor * (mel_count + 1 - frame)).pmf(support)
        for frame in range(1, mel_count + 1)
    ]
    return np.array(rows)
|
106
nemo/collections/tts/torch/readme.md
Normal file
106
nemo/collections/tts/torch/readme.md
Normal file
|
@ -0,0 +1,106 @@
|
|||
# Torch TTS Collection
|
||||
|
||||
This section of code can be used by installing the requirements inside our *requirements.txt* and *requirements_torch_tts.txt*.
|
||||
|
||||
## Install
|
||||
|
||||
This collection can be installed in the following ways:
|
||||
- pip install from github
|
||||
> ```bash
|
||||
> pip install git+https://github.com/nvidia/NeMo.git#egg=nemo_toolkit[torch_tts]
|
||||
> ```
|
||||
- inside a requirements file
|
||||
> `git+https://github.com/nvidia/NeMo.git#egg=nemo_toolkit[torch_tts]`
|
||||
- cloning from github, and then installing
|
||||
> ```bash
|
||||
> git clone https://github.com/nvidia/NeMo.git && cd NeMo && pip install ".[torch_tts]"
|
||||
> ```
|
||||
|
||||
## Usage
|
||||
|
||||
We can check that lightning is not installed by checking pip:
|
||||
```bash
|
||||
pip list | grep lightning
|
||||
```
|
||||
Now even though lightning isn't installed, we can still use parts from the torch_tts collection.
|
||||
|
||||
### TTS Dataset
|
||||
|
||||
Let's import our dataset class and then loop through the batches. Note that in the sample .json files, we only have text
|
||||
and audio. Our dataset will then create the log_mels, priors, pitches, and energies and store them in `supplementary_folder`
|
||||
which in this case is `./debug0`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from nemo.collections.tts.torch.data import CharMelAudioDataset
|
||||
|
||||
dataset = CharMelAudioDataset(
|
||||
manifest_filepath="<PATH_TO_MANIFEST_JSON>", # Path to file that describes the location of audio and text
|
||||
sample_rate=22050,
|
||||
supplementary_folder="./debug0", # An additional folder that will store log_mels, priors, pitches, and energies
|
||||
max_duration=20., # Max duration of samples in seconds
|
||||
min_duration=0.1, # Min duration of samples in seconds
|
||||
ignore_file=None,
|
||||
trim=False, # Whether to use librosa.effects.trim
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
hop_length=256,
|
||||
window="hann",
|
||||
n_mels=64, # Number of mel filters
|
||||
lowfreq=0, # lowfreq for mel filters
|
||||
highfreq=8000, # highfreq for mel filters
|
||||
pitch_fmin=80,
|
||||
pitch_fmax=640,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(dataset, 10, collate_fn=dataset._collate_fn)
|
||||
|
||||
for batch in dataloader:
|
||||
tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies = batch
|
||||
## Train models, etc.
|
||||
# Tokens represent already tokenized characters which probably will not work with previous tokenizers
|
||||
# You can get the label map from dataset.parser._labels_map. You can tokenize text via dataset.parser("text!")
|
||||
# You can detokenize using dataset.decode()
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from nemo.collections.tts.torch.data import PhoneMelAudioDataset
|
||||
|
||||
dataset = PhoneMelAudioDataset(
|
||||
manifest_filepath="<PATH_TO_MANIFEST_JSON>", # Path to file that describes the location of audio and text
|
||||
sample_rate=22050,
|
||||
supplementary_folder="./debug0", # An additional folder that will store log_mels, priors, pitches, and energies
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(dataset, 10, collate_fn=dataset._collate_fn)
|
||||
|
||||
for batch in dataloader:
|
||||
tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies = batch
|
||||
## Train models, etc.
|
||||
# Tokens represent already tokenized characters which probably will not work with previous tokenizers
|
||||
# You can tokenize via dataset.vocab.encode(), and go backwards with dataset.vocab.decode().
|
||||
```
|
||||
|
||||
## NeMo Features
|
||||
|
||||
If you look into the code we see that `TextMelAudioDataset` is a child of `nemo.core.classes.Dataset`. **You do not have to subclass this to add to the torch_tts repo**. It is sufficient to use `torch.utils.data.Dataset`. Using the nemo class adds some additional features not present in torch:
|
||||
|
||||
- *(Optional)* Adding typing information
|
||||
- Looking at `TextMelAudioDataset`, it has a `output_types` that tells us how tensors are organized. For example, the `mels` returned by this dataset has dimensions B x D x T, which is short for saying the first dimension represents batch, the second represents a generic channels / n_mel_filters dimension, and the last represents time
|
||||
- *(Optional)* *(ToDo)* Enables serialization
|
||||
- We can now call `to_config_dict()` to return a dictionary which we can now pass to `from_config_dict()` to create another instance of the dataset with the same arguments allowing us to easily restate code using these dictionaries. Please note to change any local paths if changing computers.
|
||||
|
||||
## ToDos
|
||||
|
||||
- [ ] Populate *torch_tts*
|
||||
- [x] Create a new datalayer that can be used interchangeably
|
||||
- [x] Add phone support
|
||||
- [ ] Add TTS models
|
||||
- [ ] Split Lightning away from core
|
||||
- [x] v0.1 that import checks a lot of lightning
|
||||
- [ ] Split up code (core, collections, utils) better
|
||||
- [ ] Enable building *text_normalization* without installing lightning
|
||||
- [ ] Look into how `Serialization` works without hydra
|
|
@ -25,5 +25,17 @@ from nemo.core.classes.common import (
|
|||
from nemo.core.classes.dataset import Dataset, IterableDataset
|
||||
from nemo.core.classes.exportable import Exportable, ExportFormat
|
||||
from nemo.core.classes.loss import Loss
|
||||
from nemo.core.classes.modelPT import ModelPT
|
||||
from nemo.core.classes.module import NeuralModule
|
||||
from nemo.utils import exceptions
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
try:
|
||||
import pytorch_lightning
|
||||
import hydra
|
||||
import omegaconf
|
||||
from nemo.core.classes.modelPT import ModelPT
|
||||
except ModuleNotFoundError:
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
class ModelPT(CheckInstall):
|
||||
pass
|
||||
|
|
|
@ -24,17 +24,28 @@ from functools import total_ordering
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import hydra
|
||||
import wrapt
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import nemo
|
||||
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
|
||||
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.cloud import maybe_download_from_cloud
|
||||
from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
_HAS_HYDRA = True
|
||||
try:
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
|
||||
except ModuleNotFoundError:
|
||||
_HAS_HYDRA = False
|
||||
from nemo.utils.exceptions import CheckInstall
|
||||
|
||||
class SaveRestoreConnector(CheckInstall):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck']
|
||||
|
||||
|
||||
|
@ -418,22 +429,23 @@ class Typing(ABC):
|
|||
|
||||
class Serialization(ABC):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config: DictConfig):
|
||||
def from_config_dict(cls, config: 'DictConfig'):
|
||||
"""Instantiates object using DictConfig-based configuration"""
|
||||
# Resolve the config dict
|
||||
if isinstance(config, DictConfig):
|
||||
config = OmegaConf.to_container(config, resolve=True)
|
||||
config = OmegaConf.create(config)
|
||||
OmegaConf.set_struct(config, True)
|
||||
if _HAS_HYDRA:
|
||||
if isinstance(config, DictConfig):
|
||||
config = OmegaConf.to_container(config, resolve=True)
|
||||
config = OmegaConf.create(config)
|
||||
OmegaConf.set_struct(config, True)
|
||||
|
||||
config = maybe_update_config_version(config)
|
||||
config = maybe_update_config_version(config)
|
||||
|
||||
# Hydra 0.x API
|
||||
if ('cls' in config or 'target' in config) and 'params' in config:
|
||||
if ('cls' in config or 'target' in config) and 'params' in config and _HAS_HYDRA:
|
||||
# regular hydra-based instantiation
|
||||
instance = hydra.utils.instantiate(config=config)
|
||||
# Hydra 1.x API
|
||||
elif '_target_' in config:
|
||||
elif '_target_' in config and _HAS_HYDRA:
|
||||
# regular hydra-based instantiation
|
||||
instance = hydra.utils.instantiate(config=config)
|
||||
else:
|
||||
|
@ -441,7 +453,7 @@ class Serialization(ABC):
|
|||
imported_cls_tb = None
|
||||
# Attempt class path resolution from config `target` class (if it exists)
|
||||
if 'target' in config:
|
||||
target_cls = config.target
|
||||
target_cls = config["target"] # No guarantee that this is a omegaconf class
|
||||
imported_cls = None
|
||||
try:
|
||||
# try to import the target class
|
||||
|
@ -476,15 +488,16 @@ class Serialization(ABC):
|
|||
instance._cfg = config
|
||||
return instance
|
||||
|
||||
def to_config_dict(self) -> DictConfig:
|
||||
def to_config_dict(self) -> 'DictConfig':
|
||||
"""Returns object's configuration to config dictionary"""
|
||||
if hasattr(self, '_cfg') and self._cfg is not None and isinstance(self._cfg, DictConfig):
|
||||
if hasattr(self, '_cfg') and self._cfg is not None:
|
||||
# Resolve the config dict
|
||||
config = OmegaConf.to_container(self._cfg, resolve=True)
|
||||
config = OmegaConf.create(config)
|
||||
OmegaConf.set_struct(config, True)
|
||||
if _HAS_HYDRA and isinstance(self._cfg, DictConfig):
|
||||
config = OmegaConf.to_container(self._cfg, resolve=True)
|
||||
config = OmegaConf.create(config)
|
||||
OmegaConf.set_struct(config, True)
|
||||
|
||||
config = maybe_update_config_version(config)
|
||||
config = maybe_update_config_version(config)
|
||||
|
||||
self._cfg = config
|
||||
|
||||
|
|
|
@ -156,7 +156,7 @@ class ModelPT(LightningModule, Model):
|
|||
self, config_path: str, src: str, verify_src_exists: bool = True,
|
||||
):
|
||||
""" Register model artifacts with this function. These artifacts (files) will be included inside .nemo file
|
||||
when model.save_to("mymodel.nemo") is called.
|
||||
when model.save_to("mymodel.nemo") is called.
|
||||
|
||||
How it works:
|
||||
1. It always returns existing absolute path which can be used during Model constructor call
|
||||
|
@ -174,7 +174,7 @@ class ModelPT(LightningModule, Model):
|
|||
Args:
|
||||
config_path (str): Artifact key. Usually corresponds to the model config.
|
||||
src (str): Path to artifact.
|
||||
verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will return None even if
|
||||
verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will return None even if
|
||||
src is not found. Defaults to True.
|
||||
save_restore_connector (SaveRestoreConnector): Can be overrided to add custom save and restore logic.
|
||||
|
||||
|
|
|
@ -16,7 +16,11 @@
|
|||
from nemo.utils.app_state import AppState
|
||||
from nemo.utils.nemo_logging import Logger as _Logger
|
||||
from nemo.utils.nemo_logging import LogMode as logging_mode
|
||||
from nemo.utils.lightning_logger_patch import add_memory_handlers_to_pl_logger
|
||||
|
||||
logging = _Logger()
|
||||
add_memory_handlers_to_pl_logger()
|
||||
try:
|
||||
from nemo.utils.lightning_logger_patch import add_memory_handlers_to_pl_logger
|
||||
|
||||
add_memory_handlers_to_pl_logger()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
|
|
@ -17,13 +17,19 @@ import inspect
|
|||
from dataclasses import is_dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
|
||||
from nemo.core.config.modelPT import NemoConfig
|
||||
from nemo.utils import logging
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
_HAS_HYDRA = True
|
||||
try:
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
except ModuleNotFoundError:
|
||||
_HAS_HYDRA = False
|
||||
|
||||
def update_model_config(model_cls: NemoConfig, update_cfg: DictConfig, drop_missing_subconfigs: bool = True):
|
||||
|
||||
def update_model_config(
|
||||
model_cls: 'nemo.core.config.modelPT.NemoConfig', update_cfg: 'DictConfig', drop_missing_subconfigs: bool = True
|
||||
):
|
||||
"""
|
||||
Helper class that updates the default values of a ModelPT config class with the values
|
||||
in a DictConfig that mirrors the structure of the config class.
|
||||
|
@ -52,6 +58,9 @@ def update_model_config(model_cls: NemoConfig, update_cfg: DictConfig, drop_miss
|
|||
A DictConfig with updated values that can be used to instantiate the NeMo Model along with supporting
|
||||
infrastructure.
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
|
||||
raise ValueError("`model_cfg` must be a dataclass or a structured OmegaConf object")
|
||||
|
||||
|
@ -92,7 +101,7 @@ def update_model_config(model_cls: NemoConfig, update_cfg: DictConfig, drop_miss
|
|||
|
||||
|
||||
def _update_subconfig(
|
||||
model_cfg: DictConfig, update_cfg: DictConfig, subconfig_key: str, drop_missing_subconfigs: bool
|
||||
model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str, drop_missing_subconfigs: bool
|
||||
):
|
||||
"""
|
||||
Updates the NemoConfig DictConfig such that:
|
||||
|
@ -112,6 +121,9 @@ def _update_subconfig(
|
|||
Returns:
|
||||
The updated DictConfig for the NemoConfig
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
with open_dict(model_cfg.model):
|
||||
# If update config has the key, but model cfg doesnt have the key
|
||||
# Add the update cfg subconfig to the model cfg
|
||||
|
@ -127,7 +139,7 @@ def _update_subconfig(
|
|||
return model_cfg
|
||||
|
||||
|
||||
def _add_subconfig_keys(model_cfg: DictConfig, update_cfg: DictConfig, subconfig_key: str):
|
||||
def _add_subconfig_keys(model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str):
|
||||
"""
|
||||
For certain sub-configs, the default values specified by the NemoConfig class is insufficient.
|
||||
In order to support every potential value in the merge between the `update_cfg`, it would require
|
||||
|
@ -155,6 +167,9 @@ def _add_subconfig_keys(model_cfg: DictConfig, update_cfg: DictConfig, subconfig
|
|||
Returns:
|
||||
A ModelPT DictConfig with additional keys added to the sub-config.
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
with open_dict(model_cfg.model):
|
||||
# Create copy of original model sub config
|
||||
if subconfig_key in update_cfg.model:
|
||||
|
|
|
@ -16,4 +16,22 @@
|
|||
class NeMoBaseException(Exception):
    """Root of the NeMo exception hierarchy; every exception defined in NeMo should derive from it."""
|
||||
|
||||
class LightningNotInstalledException(NeMoBaseException):
    """Raised when an object that needs pytorch_lightning, hydra and omegaconf is used without them."""

    def __init__(self, obj):
        # ``obj`` is interpolated twice so the message names what the caller tried to use
        super().__init__(
            f" You are trying to use {obj} without installing all of pytorch_lightning, hydra, and "
            f"omegaconf. Please install those packages before trying to access {obj}."
        )
|
||||
|
||||
|
||||
class CheckInstall:
    """Placeholder whose construction, call and attribute lookup all raise ``LightningNotInstalledException``."""

    def __init__(self, *args, **kwargs):
        # Instantiating the placeholder is itself an error
        raise LightningNotInstalledException(self)

    def __call__(self, *args, **kwargs):
        # Calling an instance is an error
        raise LightningNotInstalledException(self)

    def __getattr__(self, *args, **kwargs):
        # Invoked only for attributes not found by normal lookup
        raise LightningNotInstalledException(self)
|
||||
|
|
|
@ -19,14 +19,22 @@ from enum import Enum
|
|||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import wrapt
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from omegaconf import errors as omegaconf_errors
|
||||
from packaging import version
|
||||
|
||||
from nemo.utils import logging
|
||||
|
||||
# TODO @blisc: Perhaps refactor instead of import guarding
|
||||
|
||||
_HAS_HYDRA = True
|
||||
|
||||
try:
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from omegaconf import errors as omegaconf_errors
|
||||
from packaging import version
|
||||
except ModuleNotFoundError:
|
||||
_HAS_HYDRA = False
|
||||
|
||||
|
||||
_VAL_TEST_FASTPATH_KEY = 'ds_item'
|
||||
|
||||
|
||||
|
@ -49,7 +57,7 @@ class ArtifactItem:
|
|||
hashed_path: Optional[str] = None
|
||||
|
||||
|
||||
def resolve_dataset_name_from_cfg(cfg: DictConfig) -> str:
|
||||
def resolve_dataset_name_from_cfg(cfg: 'DictConfig') -> str:
|
||||
"""
|
||||
Parses items of the provided sub-config to find the first potential key that
|
||||
resolves to an existing file or directory.
|
||||
|
@ -218,6 +226,9 @@ def resolve_validation_dataloaders(model: 'ModelPT'):
|
|||
Args:
|
||||
model: ModelPT subclass, which requires >=1 Validation Dataloaders to be setup.
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
cfg = copy.deepcopy(model._cfg)
|
||||
dataloaders = []
|
||||
|
||||
|
@ -287,6 +298,9 @@ def resolve_test_dataloaders(model: 'ModelPT'):
|
|||
Args:
|
||||
model: ModelPT subclass, which requires >=1 Test Dataloaders to be setup.
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
cfg = copy.deepcopy(model._cfg)
|
||||
dataloaders = []
|
||||
|
||||
|
@ -335,7 +349,7 @@ def resolve_test_dataloaders(model: 'ModelPT'):
|
|||
|
||||
|
||||
@wrapt.decorator
|
||||
def wrap_training_step(wrapped, instance: pl.LightningModule, args, kwargs):
|
||||
def wrap_training_step(wrapped, instance: 'pl.LightningModule', args, kwargs):
|
||||
output_dict = wrapped(*args, **kwargs)
|
||||
|
||||
if isinstance(output_dict, dict) and output_dict is not None and 'log' in output_dict:
|
||||
|
@ -345,7 +359,7 @@ def wrap_training_step(wrapped, instance: pl.LightningModule, args, kwargs):
|
|||
return output_dict
|
||||
|
||||
|
||||
def convert_model_config_to_dict_config(cfg: Union[DictConfig, 'NemoConfig']) -> DictConfig:
|
||||
def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig']) -> 'DictConfig':
|
||||
"""
|
||||
Converts its input into a standard DictConfig.
|
||||
Possible input values are:
|
||||
|
@ -358,6 +372,9 @@ def convert_model_config_to_dict_config(cfg: Union[DictConfig, 'NemoConfig']) ->
|
|||
Returns:
|
||||
The equivalent DictConfig
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
|
||||
cfg = OmegaConf.structured(cfg)
|
||||
|
||||
|
@ -369,8 +386,11 @@ def convert_model_config_to_dict_config(cfg: Union[DictConfig, 'NemoConfig']) ->
|
|||
return config
|
||||
|
||||
|
||||
def _convert_config(cfg: OmegaConf):
|
||||
def _convert_config(cfg: 'OmegaConf'):
|
||||
""" Recursive function converting the configuration from old hydra format to the new one. """
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
|
||||
# Get rid of cls -> _target_.
|
||||
if 'cls' in cfg and '_target_' not in cfg:
|
||||
|
@ -391,7 +411,7 @@ def _convert_config(cfg: OmegaConf):
|
|||
logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
|
||||
|
||||
|
||||
def maybe_update_config_version(cfg: DictConfig):
|
||||
def maybe_update_config_version(cfg: 'DictConfig'):
|
||||
"""
|
||||
Recursively convert Hydra 0.x configs to Hydra 1.x configs.
|
||||
|
||||
|
@ -406,6 +426,9 @@ def maybe_update_config_version(cfg: DictConfig):
|
|||
Returns:
|
||||
An updated DictConfig that conforms to Hydra 1.x format.
|
||||
"""
|
||||
if not _HAS_HYDRA:
|
||||
logging.error("This function requires Hydra/Omegaconf and it was not installed.")
|
||||
exit(1)
|
||||
if cfg is not None and not isinstance(cfg, DictConfig):
|
||||
try:
|
||||
temp_cfg = OmegaConf.create(cfg)
|
||||
|
|
|
@ -1,19 +1,13 @@
|
|||
numpy>=1.18.2
|
||||
onnx>=1.7.0
|
||||
pytorch-lightning>=1.4.0
|
||||
python-dateutil
|
||||
torch
|
||||
torchmetrics>=0.4.1rc0
|
||||
wget
|
||||
torch>1.7
|
||||
wrapt
|
||||
ruamel.yaml
|
||||
scikit-learn
|
||||
omegaconf>=2.1.0
|
||||
hydra-core>=1.1.0
|
||||
transformers>=4.8.1
|
||||
sentencepiece<1.0.0
|
||||
webdataset>=0.1.48,<=0.1.62
|
||||
tqdm>=4.41.0
|
||||
numba
|
||||
grpcio
|
||||
grpcio-tools
|
||||
wget
|
||||
frozendict
|
||||
unidecode
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
braceexpand
|
||||
editdistance
|
||||
frozendict
|
||||
inflect
|
||||
kaldi-io
|
||||
librosa
|
||||
|
@ -9,7 +8,6 @@ packaging
|
|||
ruamel.yaml
|
||||
soundfile
|
||||
sox
|
||||
unidecode
|
||||
kaldi-python-io
|
||||
kaldiio
|
||||
scipy
|
||||
|
|
9
requirements/requirements_lightning.txt
Normal file
9
requirements/requirements_lightning.txt
Normal file
|
@ -0,0 +1,9 @@
|
|||
pytorch-lightning>=1.3.0
|
||||
torchmetrics>=0.4.1rc0
|
||||
transformers>=4.0.1
|
||||
webdataset>=0.1.48,<=0.1.62
|
||||
opencc
|
||||
pangu
|
||||
jieba
|
||||
omegaconf>=2.1.0
|
||||
hydra-core>=1.1.0
|
|
@ -2,7 +2,6 @@ boto3
|
|||
h5py
|
||||
matplotlib>=3.3.2
|
||||
sentencepiece
|
||||
unidecode
|
||||
youtokentome>=1.0.5
|
||||
numpy
|
||||
rapidfuzz
|
||||
|
|
7
requirements/requirements_torch_tts.txt
Normal file
7
requirements/requirements_torch_tts.txt
Normal file
|
@ -0,0 +1,7 @@
|
|||
matplotlib
|
||||
pypinyin
|
||||
attrdict
|
||||
pystoi
|
||||
pesq
|
||||
g2p_en
|
||||
pandas
|
|
@ -1,7 +1,2 @@
|
|||
matplotlib
|
||||
pypinyin
|
||||
attrdict
|
||||
pystoi
|
||||
pesq
|
||||
g2p_en
|
||||
librosa
|
||||
nltk
|
||||
|
|
|
@ -29,6 +29,7 @@ markers =
|
|||
docs: mark tests related to documentation (deselect with '-m "not docs"')
|
||||
skipduringci: marks tests that are skipped in CI as they are addressed by Jenkins jobs but should be run to test user setups
|
||||
pleasefixme: marks tests that are broken and need fixing
|
||||
torch_tts: marks tests that are to be run when "torch_tts" is installed but not all nemo packages
|
||||
|
||||
[isort]
|
||||
known_localfolder = nemo,tests
|
||||
|
|
19
setup.py
19
setup.py
|
@ -79,18 +79,29 @@ install_requires = req_file("requirements.txt")
|
|||
extras_require = {
|
||||
# User packages
|
||||
'test': req_file("requirements_test.txt"),
|
||||
# Collections Packages
|
||||
# NeMo Tools
|
||||
'text_processing': req_file("requirements_text_processing.txt"),
|
||||
# Torch Packages
|
||||
'torch_tts': req_file("requirements_torch_tts.txt"),
|
||||
# Lightning Collections Packages
|
||||
'core': req_file("requirements_lightning.txt"),
|
||||
'asr': req_file("requirements_asr.txt"),
|
||||
'cv': req_file("requirements_cv.txt"),
|
||||
'nlp': req_file("requirements_nlp.txt"),
|
||||
'tts': req_file("requirements_tts.txt"),
|
||||
'text_processing': req_file("requirements_text_processing.txt"),
|
||||
}
|
||||
|
||||
extras_require['all'] = list(chain(extras_require.values()))
|
||||
|
||||
# TTS depends on ASR
|
||||
extras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr']]))
|
||||
# Add lightning requirements as needed
|
||||
extras_require['test'] = list(chain([extras_require['tts'], extras_require['core']]))
|
||||
extras_require['asr'] = list(chain([extras_require['asr'], extras_require['core']]))
|
||||
extras_require['cv'] = list(chain([extras_require['cv'], extras_require['core']]))
|
||||
extras_require['nlp'] = list(chain([extras_require['nlp'], extras_require['core']]))
|
||||
extras_require['tts'] = list(chain([extras_require['tts'], extras_require['core']]))
|
||||
|
||||
# TTS has extra dependencies
|
||||
extras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr'], extras_require['torch_tts']]))
|
||||
|
||||
tests_requirements = extras_require["test"]
|
||||
|
||||
|
|
54
tests/collections/tts/test_torch_tts.py
Normal file
54
tests/collections/tts/test_torch_tts.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from nemo.collections.tts.torch.data import CharMelAudioDataset
|
||||
|
||||
|
||||
class TestCharDataset:
    """Smoke test for ``CharMelAudioDataset`` on the ``tts/mini_ljspeech`` test data."""

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.torch_tts
    def test_dataset(self, test_data_dir):
        """Build the dataset, wrap it in a ``DataLoader`` and fetch one batch.

        There are no explicit assertions: the test passes when the dataset can
        be constructed, a batch of size 2 can be collated with the dataset's
        own ``_collate_fn``, and that batch unpacks into exactly seven items.

        Args:
            test_data_dir: fixture providing the root directory of the test data.
        """
        dataset_kwargs = {
            'manifest_filepath': os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json'),
            'sample_rate': 22050,
            'supplementary_folder': os.path.join(test_data_dir, 'tts/mini_ljspeech/sup'),
        }
        char_dataset = CharMelAudioDataset(**dataset_kwargs)

        loader = torch.utils.data.DataLoader(
            char_dataset, batch_size=2, collate_fn=char_dataset._collate_fn
        )

        # The seven-way unpack doubles as the check on the batch structure.
        first_batch = next(iter(loader))
        data, _, _, _, _, _, _ = first_batch
|
||||
|
||||
|
||||
class TestPhoneDataset:
    """Smoke test that loads one batch from the ``tts/mini_ljspeech`` test data.

    NOTE(review): despite the class name, this test instantiates
    ``CharMelAudioDataset`` -- confirm whether a phoneme-based dataset was
    intended here.
    """

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.torch_tts
    def test_dataset(self, test_data_dir):
        """Build the dataset, wrap it in a ``DataLoader`` and fetch one batch.

        There are no explicit assertions: the test passes when the dataset can
        be constructed, a batch of size 2 can be collated with the dataset's
        own ``_collate_fn``, and that batch unpacks into exactly seven items.

        Args:
            test_data_dir: fixture providing the root directory of the test data.
        """
        dataset_kwargs = {
            'manifest_filepath': os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json'),
            'sample_rate': 22050,
            'supplementary_folder': os.path.join(test_data_dir, 'tts/mini_ljspeech/sup'),
        }
        tts_dataset = CharMelAudioDataset(**dataset_kwargs)

        loader = torch.utils.data.DataLoader(
            tts_dataset, batch_size=2, collate_fn=tts_dataset._collate_fn
        )

        # The seven-way unpack doubles as the check on the batch structure.
        first_batch = next(iter(loader))
        _, _, _, _, _, _, _ = first_batch
|
|
@ -15,7 +15,6 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from nemo.collections.tts.data.datalayers import FastSpeech2Dataset
|
||||
|
||||
|
|
Loading…
Reference in a new issue