New TTSDataset, tts tokenizers and g2ps (#2792)
* new vocabs and g2ps for tts
* fix style
* update tts torch data
* update g2p modules, data and add example for tts vocabs
* fix style
* update tts dataset
* add tokens field to tts dataset
* update tts dataset
* add TTSDataset and docs for all of them
* fix paths in yaml
* update test for tts dataset
* add heteronyms-030921 file to scripts folder
* change requirements_torch_tts.txt
* add tts_data_types.py
* fix style tts_data_types.py
* update yaml and comments
* update cmu dict and tts ds config
* remove unnecessary argument from tokenizers
* update test

Signed-off-by: Oktai Tatanov <oktai.tatanov@gmail.com>
This commit is contained in:
parent 3f606194f2
commit 3cde074436
@@ -16,44 +16,44 @@
import json
import pickle
from pathlib import Path
from typing import Dict, Optional
from typing import Callable, Dict, List, Optional, Union

import librosa
import torch
from nemo_text_processing.text_normalization.normalize import Normalizer
from tqdm import tqdm

from nemo.collections.asr.data.vocabs import Base, Phonemes
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.common.parts.patch_utils import stft_patch
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.collections.tts.torch.helpers import beta_binomial_prior_distribution
from nemo.collections.tts.torch.helpers import beta_binomial_prior_distribution, general_padding
from nemo.collections.tts.torch.tts_data_types import (
    DATA_STR2DATA_CLASS,
    MAIN_DATA_TYPES,
    VALID_SUPPLEMENTARY_DATA_TYPES,
    DurationPrior,
    Durations,
    Energy,
    LogMel,
    Pitch,
    WithLens,
)
from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer
from nemo.core.classes import Dataset
from nemo.core.neural_types.elements import *
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging

CONSTANT = 1e-5


class CharMelAudioDataset(Dataset):
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            'transcripts': NeuralType(('B', 'T'), TokenIndex()),
            'transcript_length': NeuralType(('B'), LengthsType()),
            'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
            'mel_length': NeuralType(('B'), LengthsType()),
            'audio': NeuralType(('B', 'T'), AudioSignal()),
            'audio_length': NeuralType(('B'), LengthsType()),
            'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
            'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
            'energies': NeuralType(('B', 'T'), RegressionValuesType()),
        }

class TTSDataset(Dataset):
    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        supplementary_folder: Path,
        text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]],
        tokens: Optional[List[str]] = None,
        text_normalizer: Optional[Union[Normalizer, Callable[[str], str]]] = None,
        text_normalizer_call_args: Optional[Dict] = None,
        text_tokenizer_pad_id: Optional[int] = None,
        sup_data_types: Optional[List[str]] = None,
        sup_data_path: Optional[Union[Path, str]] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
@@ -62,19 +62,14 @@ class CharMelAudioDataset(Dataset):
        win_length=None,
        hop_length=None,
        window="hann",
        n_mels=64,
        n_mels=80,
        lowfreq=0,
        highfreq=None,
        pitch_fmin=80,
        pitch_fmax=640,
        pitch_avg=0,
        pitch_std=1,
        tokenize_text=True,
        **kwargs,
    ):
        """Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
        Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
        they did not exist before.

        """Dataset that loads main data types (audio and text) and specified supplementary data types (e.g. log mel, durations, pitch).
        Most supplementary data types will be computed on the fly and saved in the supplementary_folder if they did not exist before.
        Arguments for supplementary data should also be specified in this class; they will be read from kwargs (see the keyword args section).
        Args:
            manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
@@ -84,8 +79,13 @@ class CharMelAudioDataset(Dataset):
                "duration": <Duration of audio clip in seconds> (Optional)
                "text": <THE_TRANSCRIPT> (Optional)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
                specified in the manifest .json file. It will also contain priors, pitches, and energies
            text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which represents text tokenizer.
            tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer.
            text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer.
            text_normalizer_call_args (Optional[Dict]): Additional arguments for the text_normalizer function.
            text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer.
            sup_data_types (Optional[List[str]]): List of supplementary data types.
            sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch).
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None, which does not prune.
@@ -100,71 +100,94 @@ class CharMelAudioDataset(Dataset):
            hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
            window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', 'none', corresponding to the
                equivalent torch window function.
            n_mels (Optional[int]): The number of mel filters. Defaults to 64.
            n_mels (Optional[int]): The number of mel filters. Defaults to 80.
            lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
            pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to None.
            pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to None.
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
            pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
            tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
        Keyword Args:
            durs_file (Optional[str]): String path to pickled durations location.
            durs_type (Optional[str]): Type of durations. Currently only "aligner-based" is supported.
            pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
            pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch.
            pitch_std (Optional[float]): The std that we use to normalize the pitch.
            pitch_norm (Optional[bool]): Whether to normalize pitch (via pitch_avg and pitch_std) or not.
        """
        super().__init__()

        self.pitch_fmin = pitch_fmin
        self.pitch_fmax = pitch_fmax
        self.pitch_avg = pitch_avg
        self.pitch_std = pitch_std
        self.win_length = win_length or n_fft
        self.sample_rate = sample_rate
        self.hop_len = hop_length or n_fft // 4
        self.text_normalizer = text_normalizer
        self.text_normalizer_call = (
            self.text_normalizer.normalize if isinstance(self.text_normalizer, Normalizer) else self.text_normalizer
        )
        self.text_normalizer_call_args = text_normalizer_call_args

        self.parser = make_parser(name="en", do_tokenize=tokenize_text)
        self.pad_id = self.parser._blank_id
        Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
        self.supplementary_folder = supplementary_folder
        self.text_tokenizer = text_tokenizer

        if isinstance(self.text_tokenizer, BaseTokenizer):
            self.text_tokenizer_pad_id = text_tokenizer.pad
            self.tokens = text_tokenizer.tokens
        else:
            if text_tokenizer_pad_id is None:
                raise ValueError("text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer")

            if tokens is None:
                raise ValueError("tokens must be specified if text_tokenizer is not BaseTokenizer")

            self.text_tokenizer_pad_id = text_tokenizer_pad_id
            self.tokens = tokens

        audio_files = []
        total_duration = 0
        # Load data from manifests
        # Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
        # extract pitch from the audio
        # Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
        # compute the mel on the fly and save it to the supplementary folder
        # Note: text is not required. Any models that rely on text (spectrogram generators, end-to-end models) will
        # fail if not set. However vocoders (mel -> audio) will be able to work without text
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
        self.manifest_filepath = manifest_filepath

        if sup_data_path is not None:
            Path(sup_data_path).mkdir(parents=True, exist_ok=True)
            self.sup_data_path = sup_data_path

        self.sup_data_types = (
            [DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types] if sup_data_types is not None else []
        )
        self.sup_data_types_set = set(self.sup_data_types)

        self.data = []
        audio_files = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                for line in tqdm(f):
                    item = json.loads(line)
                    # Grab audio, text, mel if they exist
                    file_info = {}
                    file_info["audio_filepath"] = item["audio_filepath"]
                    file_info["mel_filepath"] = item["mel_filepath"] if "mel_filepath" in item else None
                    file_info["duration"] = item["duration"] if "duration" in item else None
                    # Parse text
                    file_info["text_tokens"] = None

                    file_info = {
                        "audio_filepath": item["audio_filepath"],
                        "mel_filepath": item["mel_filepath"] if "mel_filepath" in item else None,
                        "duration": item["duration"] if "duration" in item else None,
                        "text_tokens": None,
                    }

                    if "text" in item:
                        text = item["text"]
                        text_tokens = self.parser(text)

                        if self.text_normalizer is not None:
                            text = self.text_normalizer_call(text, **self.text_normalizer_call_args)

                        text_tokens = self.text_tokenizer(text)
                        file_info["raw_text"] = item["text"]
                        file_info["text_tokens"] = text_tokens

                    audio_files.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(audio_files)} files.")
        if total_duration is not None:
            logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")

        self.data = []
            logging.info(f"Dataset contains {total_duration / 3600:.2f} hours.")

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
@@ -197,325 +220,300 @@ class CharMelAudioDataset(Dataset):
        logging.info(f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files")
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
                f"{(total_duration-pruned_duration)/3600:.2f} hours."
                f"Pruned {pruned_duration / 3600:.2f} hours. Final dataset contains "
                f"{(total_duration - pruned_duration) / 3600:.2f} hours."
            )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.sample_rate = sample_rate
        self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
        self.trim = trim

        filterbanks = torch.tensor(
            librosa.filters.mel(sample_rate, n_fft, n_mels=n_mels, fmin=lowfreq, fmax=highfreq), dtype=torch.float
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.lowfreq = lowfreq
        self.highfreq = highfreq
        self.window = window
        self.win_length = win_length or self.n_fft
        self.hop_length = hop_length
        self.hop_len = self.hop_length or self.n_fft // 4
        self.fb = torch.tensor(
            librosa.filters.mel(
                self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.lowfreq, fmax=self.highfreq
            ),
            dtype=torch.float,
        ).unsqueeze(0)
        self.fb = filterbanks

        torch_windows = {
        window_fn = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        }.get(self.window, None)

        self.stft = lambda x: stft_patch(
            input=x,
            n_fft=n_fft,
            n_fft=self.n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_tensor.to(torch.float),
            window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None,
        )

        for data_type in self.sup_data_types:
            if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
                raise NotImplementedError(f"Current implementation of TTSDataset doesn't support {data_type} type.")

            getattr(self, f"add_{data_type.name}")(**kwargs)

    def add_log_mel(self, **kwargs):
        pass

    def add_durations(self, **kwargs):
        durs_file = kwargs.pop('durs_file')
        durs_type = kwargs.pop('durs_type')

        audio_stem2durs = torch.load(durs_file)
        self.durs = []

        for tag in [Path(d["audio_filepath"]).stem for d in self.data]:
            durs = audio_stem2durs[tag]
            if durs_type == "aligner-based":
                self.durs.append(durs)
            else:
                raise NotImplementedError(
                    f"{durs_type} duration type is not supported. Only aligner-based is supported at this moment."
                )

    def add_duration_prior(self, **kwargs):
        pass

    def add_pitch(self, **kwargs):
        self.pitch_fmin = kwargs.pop("pitch_fmin", librosa.note_to_hz('C2'))
        self.pitch_fmax = kwargs.pop("pitch_fmax", librosa.note_to_hz('C7'))
        self.pitch_avg = kwargs.pop("pitch_avg", None)
        self.pitch_std = kwargs.pop("pitch_std", None)
        self.pitch_norm = kwargs.pop("pitch_norm", False)

    def add_energy(self, **kwargs):
        pass

    def get_spec(self, audio):
        with torch.cuda.amp.autocast(enabled=False):
            spec = self.stft(audio)
            if spec.dtype in [torch.cfloat, torch.cdouble]:
                spec = torch.view_as_real(spec)
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
        return spec

    def get_log_mel(self, audio):
        with torch.cuda.amp.autocast(enabled=False):
            spec = self.get_spec(audio)
            mel = torch.matmul(self.fb.to(spec.dtype), spec)
            log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
        return log_mel

    def __getitem__(self, index):
        spec = None
        sample = self.data[index]
        audio_stem = Path(sample["audio_filepath"]).stem

        features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
        audio, audio_length = features, torch.tensor(features.shape[0]).long()
        if isinstance(sample["text_tokens"], str):
            # If tokenize_text is False for Phone dataset
            text = sample["text_tokens"]
            text_length = None
        else:
            text = torch.tensor(sample["text_tokens"]).long()
            text_length = torch.tensor(len(sample["text_tokens"])).long()
        audio_stem = Path(sample["audio_filepath"]).stem

        # Load mel if it exists
        mel_path = sample["mel_filepath"]
        if mel_path and Path(mel_path).exists():
            log_mel = torch.load(mel_path)
        else:
            mel_path = Path(self.supplementary_folder) / f"mel_{audio_stem}.pt"
            if mel_path.exists():
        text = torch.tensor(sample["text_tokens"]).long()
        text_length = torch.tensor(len(sample["text_tokens"])).long()

        log_mel, log_mel_length = None, None
        if LogMel in self.sup_data_types_set:
            mel_path = sample["mel_filepath"]

            if mel_path is not None and Path(mel_path).exists():
                log_mel = torch.load(mel_path)
            else:
                # disable autocast to get full range of stft values
                with torch.cuda.amp.autocast(enabled=False):
                    spec = self.stft(audio)
                mel_path = Path(self.sup_data_path) / f"mel_{audio_stem}.pt"

                    # guard is needed for sqrt if grads are passed through
                    guard = CONSTANT  # TODO: Enable 0 if not self.use_grads else CONSTANT
                    if spec.dtype in [torch.cfloat, torch.cdouble]:
                        spec = torch.view_as_real(spec)
                    spec = torch.sqrt(spec.pow(2).sum(-1) + guard)

                    mel = torch.matmul(self.fb.to(spec.dtype), spec)

                    log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
                if mel_path.exists():
                    log_mel = torch.load(mel_path)
                else:
                    log_mel = self.get_log_mel(audio)
                    torch.save(log_mel, mel_path)

        log_mel = log_mel.squeeze(0)
        log_mel_length = torch.tensor(log_mel.shape[1]).long()
            log_mel = log_mel.squeeze(0)
            log_mel_length = torch.tensor(log_mel.shape[1]).long()

        durations = None
        if Durations in self.sup_data_types_set:
            durations = self.durs[index]

        duration_prior = None
        if text_length is not None:
            ### Make duration attention prior if it does not exist in the supplementary folder
            prior_path = Path(self.supplementary_folder) / f"pr_tl{text_length}_al_{log_mel_length}.pt"
        if DurationPrior in self.sup_data_types_set:
            prior_path = Path(self.sup_data_path) / f"pr_{audio_stem}.pt"

            if prior_path.exists():
                duration_prior = torch.load(prior_path)
            else:
                log_mel_length = torch.tensor(self.get_log_mel(audio).squeeze(0).shape[1]).long()
                duration_prior = beta_binomial_prior_distribution(text_length, log_mel_length)
                duration_prior = torch.from_numpy(duration_prior)
                torch.save(duration_prior, prior_path)

        # Load pitch file (F0s)
        pitch_path = (
            Path(self.supplementary_folder)
            / f"{audio_stem}_pitch_pyin_fmin{self.pitch_fmin}_fmax{self.pitch_fmax}_fl{self.win_length}_hs{self.hop_len}.pt"
        )
        if pitch_path.exists():
            pitch = torch.load(pitch_path)
        else:
            pitch, _, _ = librosa.pyin(
                audio.numpy(),
                fmin=self.pitch_fmin,
                fmax=self.pitch_fmax,
                frame_length=self.win_length,
                sr=self.sample_rate,
                fill_na=0.0,
            )
            pitch = torch.from_numpy(pitch)
            torch.save(pitch, pitch_path)
        # Standardize pitch
        pitch -= self.pitch_avg
        pitch[pitch == -self.pitch_avg] = 0.0  # Zero out values that were previously zero
        pitch /= self.pitch_std

        # Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
        energy_path = Path(self.supplementary_folder) / f"{audio_stem}_energy_wl{self.win_length}_hs{self.hop_len}.pt"
        if energy_path.exists():
            energy = torch.load(energy_path)
        else:
            if spec is None:
                spec = self.stft(audio)
            energy = torch.linalg.norm(spec.squeeze(0), axis=0)
            # Save to new file
            torch.save(energy, energy_path)

        return text, text_length, log_mel, log_mel_length, audio, audio_length, duration_prior, pitch, energy

    def __len__(self):
        return len(self.data)

    def _collate_fn(self, batch):
        log_mel_pad = torch.finfo(batch[0][2].dtype).tiny

        _, tokens_lengths, _, log_mel_lengths, _, audio_lengths, duration_priors_list, pitches, energies = zip(*batch)

        max_tokens_len = max(tokens_lengths).item()
        max_log_mel_len = max(log_mel_lengths)
        max_audio_len = max(audio_lengths).item()
        max_pitches_len = max([len(i) for i in pitches])
        max_energies_len = max([len(i) for i in energies])
        if max_pitches_len != max_energies_len or max_pitches_len != max_log_mel_len:
            logging.warning(
                f"max_pitches_len: {max_pitches_len} != max_energies_len: {max_energies_len} != "
                f"max_mel_len:{max_log_mel_len}. Your training run will error out!"
        pitch, pitch_length = None, None
        if Pitch in self.sup_data_types_set:
            pitch_name = (
                f"{audio_stem}_pitch_pyin_"
                f"fmin{self.pitch_fmin}_fmax{self.pitch_fmax}_"
                f"fl{self.win_length}_hs{self.hop_len}.pt"
            )

        # Define empty lists to be batched
        duration_priors = torch.zeros(
            len(duration_priors_list),
            max([prior_i.shape[0] for prior_i in duration_priors_list]),
            max([prior_i.shape[1] for prior_i in duration_priors_list]),
        )
        audios, tokens, log_mels, pitches, energies = [], [], [], [], []
        for i, sample_tuple in enumerate(batch):
            token, token_len, log_mel, log_mel_len, audio, audio_len, duration_prior, pitch, energy = sample_tuple
            # Pad text tokens
            token_len = token_len.item()
            if token_len < max_tokens_len:
                pad = (0, max_tokens_len - token_len)
                token = torch.nn.functional.pad(token, pad, value=self.pad_id)
            tokens.append(token)
            # Pad mel
            log_mel_len = log_mel_len
            if log_mel_len < max_log_mel_len:
                pad = (0, max_log_mel_len - log_mel_len)
                log_mel = torch.nn.functional.pad(log_mel, pad, value=log_mel_pad)
            log_mels.append(log_mel)
            # Pad audio
            audio_len = audio_len.item()
            if audio_len < max_audio_len:
                pad = (0, max_audio_len - audio_len)
                audio = torch.nn.functional.pad(audio, pad)
            audios.append(audio)
            # Pad duration_prior
            duration_priors[i, : duration_prior.shape[0], : duration_prior.shape[1]] = duration_prior
            # Pad pitch
            if len(pitch) < max_pitches_len:
                pad = (0, max_pitches_len - len(pitch))
                pitch = torch.nn.functional.pad(pitch, pad)
            pitches.append(pitch)
            # Pad energy
            if len(energy) < max_energies_len:
                pad = (0, max_energies_len - len(energy))
                energy = torch.nn.functional.pad(energy, pad)
            energies.append(energy)

        audios = torch.stack(audios)
        audio_lengths = torch.stack(audio_lengths)
        tokens = torch.stack(tokens)
        tokens_lengths = torch.stack(tokens_lengths)
        log_mels = torch.stack(log_mels)
        log_mel_lengths = torch.stack(log_mel_lengths)
        pitches = torch.stack(pitches)
        energies = torch.stack(energies)

        logging.debug(f"audios: {audios.shape}")
        logging.debug(f"audio_lengths: {audio_lengths.shape}")
        logging.debug(f"tokens: {tokens.shape}")
        logging.debug(f"tokens_lengths: {tokens_lengths.shape}")
        logging.debug(f"log_mels: {log_mels.shape}")
        logging.debug(f"log_mel_lengths: {log_mel_lengths.shape}")
        logging.debug(f"duration_priors: {duration_priors.shape}")
        logging.debug(f"pitches: {pitches.shape}")
        logging.debug(f"energies: {energies.shape}")

        return (tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies)

    def decode(self, tokens):
        assert len(tokens.squeeze().shape) in [0, 1]
        return self.parser.decode(tokens)


class PhoneMelAudioDataset(CharMelAudioDataset):
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            'transcripts': NeuralType(('B', 'T'), TokenIndex()),
            'transcript_length': NeuralType(('B'), LengthsType()),
            'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
            'mel_length': NeuralType(('B'), LengthsType()),
            'audio': NeuralType(('B', 'T'), AudioSignal()),
            'audio_length': NeuralType(('B'), LengthsType()),
            'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
            'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
            'energies': NeuralType(('B', 'T'), RegressionValuesType()),
        }

    def __init__(
        self,
        punct=True,
        stresses=False,
        spaces=True,
        chars=False,
        space=' ',
        silence=None,
        apostrophe=True,
        oov=Base.OOV,
        sep='|',
        add_blank_at="last_but_one",
        pad_with_space=False,
        improved_version_g2p=False,
        phoneme_dict_path=None,
        **kwargs,
    ):
        """Dataset which extends CharMelAudioDataset to load phones in place of characters. It returns audio, log mel
        specs, phone tokens, duration / attention priors, pitches, and energies. Log mels, priors, pitches, and
        energies will be computed on the fly and saved in the supplementary_folder if they did not exist before. These
        supplementary files can be shared with CharMelAudioDataset.

        Args:
            punct (bool): Whether to keep punctuation in the input. Defaults to True
            stresses (bool): Whether to add phone stresses in the input. Defaults to False
            spaces (bool): Whether to encode space characters. Defaults to True
            chars (bool): Whether to add characters to the labels map. NOTE: The current parser class does not
                actually parse transcripts to characters. Defaults to False
            space (str): The space character. Defaults to ' '
            silence (bool): Whether to add silence tokens. Defaults to False
            apostrophe (bool): Whether to keep apostrophes. Defaults to True
            oov (str): How out of vocabulary tokens are decoded. Defaults to Base.OOV == "<oov>"
            sep (str): How to separate phones when tokens are decoded. Defaults to "|"
            add_blank_at (str): Where to add the blank symbol that is used in CTC. Can be None which does not add a
                blank token in the vocab, "last" which makes self.vocab.labels[-1] the blank token, or
                "last_but_one" which makes self.vocab.labels[-2] the blank token
            pad_with_space (bool): Whether to pad input with space tokens at start and end. Defaults to False
            improved_version_g2p (bool): Defaults to False
            phoneme_dict_path (path): Location of cmudict. Defaults to None which means the code will download it
                automatically
        """
        if "tokenize_text" in kwargs:
            tokenize_text = kwargs.pop("tokenize_text")
            if not tokenize_text:
                logging.warning(
                    f"{self} requires tokenize_text to be False. Setting it to False and ignoring provided value of "
                    f"{tokenize_text}"
            pitch_path = Path(self.sup_data_path) / pitch_name
            if pitch_path.exists():
                pitch = torch.load(pitch_path).float()
            else:
                pitch, _, _ = librosa.pyin(
                    audio.numpy(),
                    fmin=self.pitch_fmin,
                    fmax=self.pitch_fmax,
                    frame_length=self.win_length,
                    sr=self.sample_rate,
                    fill_na=0.0,
                )
        super().__init__(tokenize_text=False, **kwargs)
                pitch = torch.from_numpy(pitch).float()
                torch.save(pitch, pitch_path)

        self.vocab = Phonemes(
            punct=punct,
            stresses=stresses,
            spaces=spaces,
            chars=chars,
            add_blank_at=add_blank_at,
            pad_with_space=pad_with_space,
            improved_version_g2p=improved_version_g2p,
            phoneme_dict_path=phoneme_dict_path,
            space=space,
            silence=silence,
            apostrophe=apostrophe,
            oov=oov,
            sep=sep,
        )
        self.pad_id = self.vocab.pad
            if self.pitch_avg is not None and self.pitch_std is not None and self.pitch_norm:
                pitch -= self.pitch_avg
                pitch[pitch == -self.pitch_avg] = 0.0  # Zero out values that were previously zero
                pitch /= self.pitch_std

    def __getitem__(self, index):
        (text, _, log_mel, log_mel_length, audio, audio_length, _, pitch, energy) = super().__getitem__(index)
            pitch_length = torch.tensor(len(pitch)).long()

        phones_tokenized = torch.tensor(self.vocab.encode(text)).long()
        phones_length = torch.tensor(len(phones_tokenized)).long()
        energy, energy_length = None, None
        if Energy in self.sup_data_types_set:
            energy_path = Path(self.sup_data_path) / f"{audio_stem}_energy_wl{self.win_length}_hs{self.hop_len}.pt"
            if energy_path.exists():
                energy = torch.load(energy_path).float()
            else:
                spec = self.get_spec(audio)
                energy = torch.linalg.norm(spec.squeeze(0), axis=0).float()
                torch.save(energy, energy_path)

        ### Make duration attention prior if it does not exist in the supplementary folder
        prior_path = Path(self.supplementary_folder) / f"pr_tl{phones_length}_al_{log_mel_length}.pt"
        if prior_path.exists():
            duration_prior = torch.load(prior_path)
        else:
            duration_prior = beta_binomial_prior_distribution(phones_length, log_mel_length)
            duration_prior = torch.from_numpy(duration_prior)
            torch.save(duration_prior, prior_path)
            energy_length = torch.tensor(len(energy)).long()

        return (
            phones_tokenized,
            phones_length,
            log_mel,
            log_mel_length,
            audio,
            audio_length,
            text,
            text_length,
            log_mel,
            log_mel_length,
            durations,
            duration_prior,
            pitch,
            pitch_length,
            energy,
            energy_length,
        )

    def __len__(self):
        return len(self.data)

    def decode(self, tokens):
        """
        Accepts a single list of tokens, not a batch
        """
        assert len(tokens.squeeze().shape) in [0, 1]
        return self.vocab.decode(tokens)

    def join_data(self, data_dict):
        result = []
        for data_type in MAIN_DATA_TYPES + self.sup_data_types:
            result.append(data_dict[data_type.name])

            if issubclass(data_type, WithLens):
                result.append(data_dict[f"{data_type.name}_lens"])

        return tuple(result)

    def general_collate_fn(self, batch):
        (
            _,
            audio_lengths,
            _,
            tokens_lengths,
            _,
            log_mel_lengths,
            durations_list,
            duration_priors_list,
            pitches,
            pitches_lengths,
            energies,
            energies_lengths,
        ) = zip(*batch)

        max_audio_len = max(audio_lengths).item()
        max_tokens_len = max(tokens_lengths).item()
        max_log_mel_len = max(log_mel_lengths) if LogMel in self.sup_data_types_set else None
        max_durations_len = max([len(i) for i in durations_list]) if Durations in self.sup_data_types_set else None
        max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None
        max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None

        if LogMel in self.sup_data_types_set:
            log_mel_pad = torch.finfo(batch[0][2].dtype).tiny

        duration_priors = (
            torch.zeros(
                len(duration_priors_list),
                max([prior_i.shape[0] for prior_i in duration_priors_list]),
                max([prior_i.shape[1] for prior_i in duration_priors_list]),
            )
            if DurationPrior in self.sup_data_types_set
            else []
        )
        audios, tokens, log_mels, durations_list, pitches, energies = [], [], [], [], [], []

        for i, sample_tuple in enumerate(batch):
            (
                audio,
                audio_len,
                token,
                token_len,
                log_mel,
                log_mel_len,
                durations,
                duration_prior,
                pitch,
                pitch_length,
                energy,
                energy_length,
            ) = sample_tuple

            audio = general_padding(audio, audio_len.item(), max_audio_len)
            audios.append(audio)

            token = general_padding(token, token_len.item(), max_tokens_len, pad_value=self.text_tokenizer_pad_id)
            tokens.append(token)

            if LogMel in self.sup_data_types_set:
                log_mels.append(general_padding(log_mel, log_mel_len, max_log_mel_len, pad_value=log_mel_pad))
            if Durations in self.sup_data_types_set:
                durations_list.append(general_padding(durations, len(durations), max_durations_len))
            if DurationPrior in self.sup_data_types_set:
                duration_priors[i, : duration_prior.shape[0], : duration_prior.shape[1]] = duration_prior
            if Pitch in self.sup_data_types_set:
                pitches.append(general_padding(pitch, pitch_length.item(), max_pitches_len))
            if Energy in self.sup_data_types_set:
                energies.append(general_padding(energy, energy_length.item(), max_energies_len))

        data_dict = {
            "audio": torch.stack(audios),
            "audio_lens": torch.stack(audio_lengths),
            "text": torch.stack(tokens),
            "text_lens": torch.stack(tokens_lengths),
            "log_mel": torch.stack(log_mels) if LogMel in self.sup_data_types_set else None,
            "log_mel_lens": torch.stack(log_mel_lengths) if LogMel in self.sup_data_types_set else None,
            "durations": torch.stack(durations_list) if Durations in self.sup_data_types_set else None,
            "duration_prior": duration_priors if DurationPrior in self.sup_data_types_set else None,
            "pitch": torch.stack(pitches) if Pitch in self.sup_data_types_set else None,
            "pitch_lens": torch.stack(pitches_lengths) if Pitch in self.sup_data_types_set else None,
            "energy": torch.stack(energies) if Energy in self.sup_data_types_set else None,
            "energy_lens": torch.stack(energies_lengths) if Energy in self.sup_data_types_set else None,
        }

        return data_dict

    def _collate_fn(self, batch):
        data_dict = self.general_collate_fn(batch)
        joined_data = self.join_data(data_dict)
        return joined_data
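To make the collate contract concrete, here is a minimal sketch (illustrative, not part of the commit) of the tuple ordering that `join_data` produces. The class names mirror `tts_data_types.py` further down; every `WithLens` type contributes its value followed by a lengths entry.

```python
# Stand-in classes mirroring tts_data_types.py (hypothetical, for illustration only).
class WithLens:
    pass

class Audio(WithLens):
    name = "audio"

class Text(WithLens):
    name = "text"

class LogMel(WithLens):
    name = "log_mel"

class Pitch(WithLens):
    name = "pitch"

def join_order(main_types, sup_types):
    # Same loop as TTSDataset.join_data, but over names instead of tensors.
    order = []
    for data_type in main_types + sup_types:
        order.append(data_type.name)
        if issubclass(data_type, WithLens):
            order.append(f"{data_type.name}_lens")
    return order

print(join_order([Audio, Text], [LogMel, Pitch]))
# ['audio', 'audio_lens', 'text', 'text_lens', 'log_mel', 'log_mel_lens', 'pitch', 'pitch_lens']
```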
nemo/collections/tts/torch/g2ps.py (new file, 226 lines)
@@ -0,0 +1,226 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import pathlib
import re
import time

import nltk
import torch

from nemo.collections.tts.torch.tts_tokenizers import english_text_preprocessing, english_word_tokenize
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero


class BaseG2p(abc.ABC):
    def __init__(
        self,
        phoneme_dict=None,
        text_preprocessing_func=lambda x: x,
        word_tokenize_func=lambda x: x,
        apply_to_oov_word=None,
    ):
        """Abstract class for creating an arbitrary module to convert grapheme words to phoneme sequences (or leave them unchanged, or handle them with apply_to_oov_word).
        Args:
            phoneme_dict: Arbitrary representation of a dictionary (grapheme -> phoneme) for known words.
            text_preprocessing_func: Function for preprocessing raw text.
            word_tokenize_func: Function for tokenizing text into words.
            apply_to_oov_word: Function applied to words that are missing from phoneme_dict (OOV words).
        """
        self.phoneme_dict = phoneme_dict
        self.text_preprocessing_func = text_preprocessing_func
        self.word_tokenize_func = word_tokenize_func
        self.apply_to_oov_word = apply_to_oov_word

    @abc.abstractmethod
    def __call__(self, text: str) -> str:
        pass


class EnglishG2p(BaseG2p):
    def __init__(
        self,
        phoneme_dict=None,
        text_preprocessing_func=english_text_preprocessing,
        word_tokenize_func=english_word_tokenize,
        apply_to_oov_word=None,
        ignore_ambiguous_words=True,
        heteronyms=None,
        encoding='latin-1',
    ):
        """English G2P module. This module converts words from grapheme to phoneme representation using a phoneme_dict in CMU dict format.
        Optionally, it can ignore words which are heteronyms, ambiguous, or marked as unchangeable by word_tokenize_func (see code for details).
        Ignored words are left unchanged or passed through apply_to_oov_word.
        Args:
            phoneme_dict (str, Path, Dict): Path to a file in CMU dict format, or a dictionary in CMU dict format.
            text_preprocessing_func: Function for preprocessing raw text into preprocessed text.
            word_tokenize_func: Function for tokenizing text into words.
                It has to return List[Tuple[Union[str, List[str]], bool]], where every tuple denotes a word representation and a flag indicating whether to leave it unchanged.
                Unchangeable words are expected to be represented as List[str]; all other cases are represented as str.
                It is useful to mark as unchangeable a word that is already in phoneme representation.
            apply_to_oov_word: Function applied to words that are missing from phoneme_dict (OOV words).
            ignore_ambiguous_words: Whether to skip phoneme_dict lookup for words with ambiguous (multiple) phoneme sequences. Defaults to True.
            heteronyms (str, Path, List): Path to a file with heteronyms (one word per line) or a list of words.
            encoding: Encoding type.
        """
        phoneme_dict = (
            self._parse_as_cmu_dict(phoneme_dict, encoding)
            if isinstance(phoneme_dict, str) or isinstance(phoneme_dict, pathlib.Path) or phoneme_dict is None
            else phoneme_dict
        )

        super().__init__(
            phoneme_dict=phoneme_dict,
            text_preprocessing_func=text_preprocessing_func,
            word_tokenize_func=word_tokenize_func,
            apply_to_oov_word=apply_to_oov_word,
        )

        self.ignore_ambiguous_words = ignore_ambiguous_words
        self.heteronyms = (
            set(self._parse_file_by_lines(heteronyms, encoding))
            if isinstance(heteronyms, str) or isinstance(heteronyms, pathlib.Path)
            else heteronyms
        )

    @staticmethod
    def _parse_as_cmu_dict(phoneme_dict_path=None, encoding='latin-1'):
        if phoneme_dict_path is None:
            # this part of the code downloads a file, but it is not rank zero guarded
            # Try to check if torch distributed is available, if not get global rank zero to download corpora and make
            # all other ranks sleep for a minute
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                group = torch.distributed.group.WORLD
                if is_global_rank_zero():
                    try:
                        nltk.data.find('corpora/cmudict.zip')
                    except LookupError:
                        nltk.download('cmudict', quiet=True)
                torch.distributed.barrier(group=group)
            elif is_global_rank_zero():
                logging.error(
                    "Torch distributed needs to be initialized before you initialize EnglishG2p. This class is prone to "
                    "data access race conditions. Now downloading corpora from global rank 0. If other ranks pass this "
                    "before rank 0, errors might result."
                )
                try:
                    nltk.data.find('corpora/cmudict.zip')
                except LookupError:
                    nltk.download('cmudict', quiet=True)
            else:
                logging.error(
                    "Torch distributed needs to be initialized before you initialize EnglishG2p. This class is prone to "
                    "data access race conditions. This process is not rank 0, and now going to sleep for 1 min. If this "
                    "rank wakes from sleep prior to rank 0 finishing downloading, errors might result."
                )
                time.sleep(60)

            logging.warning("phoneme_dict_path=None, English g2p_dict will be used from nltk.corpus.cmudict.dict()")

            return nltk.corpus.cmudict.dict()

        _alt_re = re.compile(r'\([0-9]+\)')
        g2p_dict = {}
        with open(phoneme_dict_path, encoding=encoding) as file:
            for line in file:
                if len(line) and ('A' <= line[0] <= 'Z' or line[0] == "'"):
                    parts = line.split('  ')  # in CMU dict format, word and pronunciation are separated by two spaces
                    word = re.sub(_alt_re, '', parts[0])
                    word = word.lower()

                    pronunciation = parts[1].strip().split(" ")
                    if word in g2p_dict:
                        g2p_dict[word].append(pronunciation)
                    else:
                        g2p_dict[word] = [pronunciation]
        return g2p_dict

    @staticmethod
    def _parse_file_by_lines(p, encoding):
        with open(p, encoding=encoding) as f:
            return [l.rstrip() for l in f.readlines()]

    def is_unique_in_phoneme_dict(self, word):
        return len(self.phoneme_dict[word]) == 1

    def parse_one_word(self, word: str):
        """
        Returns parsed `word` and `status` as bool.
        `status` will be `False` if word wasn't handled, `True` otherwise.
        """

        # punctuation
        if re.search("[a-zA-Z]", word) is None:
            return list(word), True

        # heteronym
        if self.heteronyms is not None and word in self.heteronyms:
            return word, True

        # `'s` suffix
        if (
            len(word) > 2
            and word.endswith("'s")
            and (word not in self.phoneme_dict)
            and (word[:-2] in self.phoneme_dict)
            and (not self.ignore_ambiguous_words or self.is_unique_in_phoneme_dict(word[:-2]))
        ):
            return self.phoneme_dict[word[:-2]][0] + ["Z"], True

        # `s` suffix
        if (
            len(word) > 1
            and word.endswith("s")
            and (word not in self.phoneme_dict)
            and (word[:-1] in self.phoneme_dict)
            and (not self.ignore_ambiguous_words or self.is_unique_in_phoneme_dict(word[:-1]))
        ):
            return self.phoneme_dict[word[:-1]][0] + ["Z"], True

        # phoneme dict
        if word in self.phoneme_dict and (not self.ignore_ambiguous_words or self.is_unique_in_phoneme_dict(word)):
            return self.phoneme_dict[word][0], True

        if self.apply_to_oov_word is not None:
            return self.apply_to_oov_word(word), False
        else:
            return word, False

    def __call__(self, text):
        text = self.text_preprocessing_func(text)
        words = self.word_tokenize_func(text)

        prons = []
        for word, without_changes in words:
            if without_changes:
                prons.extend(word)
                continue

            word_by_hyphen = word.split("-")

            pron, is_handled = self.parse_one_word(word)

            if not is_handled and len(word_by_hyphen) > 1:
                pron = []
                for sub_word in word_by_hyphen:
                    p, _ = self.parse_one_word(sub_word)
                    pron.extend(p)
                    pron.extend(["-"])
                pron.pop()

            prons.extend(pron)

        return prons
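To see the dictionary rules of `parse_one_word` in action, here is a toy re-statement with a made-up two-word dictionary (hypothetical; the real class loads CMUdict and additionally handles heteronyms, `'s` suffixes, and ambiguity):

```python
import re

# Toy stand-in for the CMU dict mapping: word -> list of pronunciations (made-up subset).
phoneme_dict = {"hello": [["HH", "AH0", "L", "OW1"]], "world": [["W", "ER1", "L", "D"]]}

def parse_one_word(word):
    if re.search("[a-zA-Z]", word) is None:  # punctuation passes through unchanged
        return list(word), True
    if word.endswith("s") and word not in phoneme_dict and word[:-1] in phoneme_dict:
        return phoneme_dict[word[:-1]][0] + ["Z"], True  # plural: base pronunciation + "Z"
    if word in phoneme_dict:
        return phoneme_dict[word][0], True
    return word, False  # OOV: left unchanged, status False

print(parse_one_word("worlds"))  # (['W', 'ER1', 'L', 'D', 'Z'], True)
print(parse_one_word("!"))       # (['!'], True)
print(parse_one_word("nemo"))    # ('nemo', False)
```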
@@ -13,9 +13,16 @@
# limitations under the License.

import numpy as np
import torch
from scipy.stats import betabinom


def general_padding(item, item_len, max_len, pad_value=0):
    if item_len < max_len:
        item = torch.nn.functional.pad(item, (0, max_len - item_len), value=pad_value)
    return item


def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    x = np.arange(0, phoneme_count)
    mel_text_probs = []
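A minimal use of the `general_padding` helper added above (illustrative, not part of the commit; assumes NeMo is installed):

```python
import torch

from nemo.collections.tts.torch.helpers import general_padding

x = torch.ones(3)
# Pads the last dimension up to max_len with pad_value (0 by default).
print(general_padding(x, item_len=3, max_len=5))  # tensor([1., 1., 1., 0., 0.])
```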
@@ -22,83 +22,39 @@ We can check that lightning is not installed by checking pip:
```bash
pip list | grep lightning
```
Now even though lightning isn't installed, we can still use parts from the torch_tts collection.
Now even though lightning isn't installed, we can still use parts from the `torch_tts` collection.

### TTS Dataset
### TTS Dataset: example

Let's import our dataset class and then loop through the batches. Note that in the sample .json files, we only have text
and audio. Our dataset will then create the log_mels, priors, pitches, and energies and store them in `supplementary_folder`,
which in this case is `./debug0`.
Let's import our dataset class, loop through the batches, and do a simple task: calculate pitch statistics. Note that in the sample .json files, we only have text
and audio. Our dataset will then create supplementary data (e.g. pitch) and store it in `supplementary_folder`. You can find the config in `tts_dataset.yaml`.

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tqdm import tqdm

from nemo.collections.tts.torch.data import CharMelAudioDataset
cfg = OmegaConf.load("nemo/collections/tts/torch/tts_dataset.yaml")

dataset = CharMelAudioDataset(
    manifest_filepath="<PATH_TO_MANIFEST_JSON>",  # Path to file that describes the location of audio and text
    sample_rate=22050,
    supplementary_folder="./debug0",  # An additional folder that will store log_mels, priors, pitches, and energies
    max_duration=20.,  # Max duration of samples in seconds
    min_duration=0.1,  # Min duration of samples in seconds
    ignore_file=None,
    trim=False,  # Whether to use librosa.effects.trim
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    window="hann",
    n_mels=64,  # Number of mel filters
    lowfreq=0,  # lowfreq for mel filters
    highfreq=8000,  # highfreq for mel filters
    pitch_fmin=80,
    pitch_fmax=640,
)
dataset = instantiate(cfg.tts_dataset)
dataloader = torch.utils.data.DataLoader(dataset, 1, collate_fn=dataset._collate_fn, num_workers=1)

dataloader = torch.utils.data.DataLoader(dataset, 10, collate_fn=dataset._collate_fn)
pitch_list = []
for batch in tqdm(dataloader, total=len(dataloader)):
    tokens, tokens_lengths, audios, audio_lengths, pitches, pitches_lengths = batch
    pitch = pitches.squeeze(0)
    pitch_list.append(pitch[pitch != 0])

for batch in dataloader:
    tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies = batch
    ## Train models, etc.
    # Tokens represent already tokenized characters which probably will not work with previous tokenizers
    # You can get the label map from dataset.parser._labels_map. You can tokenize text via dataset.parser("text!")
    # You can detokenize using dataset.decode()
pitch_tensor = torch.cat(pitch_list)
print(f"PITCH_MEAN, PITCH_STD = {pitch_tensor.mean().item()}, {pitch_tensor.std().item()}")
```

```python
import torch

from nemo.collections.tts.torch.data import PhoneMelAudioDataset

dataset = PhoneMelAudioDataset(
    manifest_filepath="<PATH_TO_MANIFEST_JSON>",  # Path to file that describes the location of audio and text
    sample_rate=22050,
    supplementary_folder="./debug0",  # An additional folder that will store log_mels, priors, pitches, and energies
)

dataloader = torch.utils.data.DataLoader(dataset, 10, collate_fn=dataset._collate_fn)

for batch in dataloader:
    tokens, tokens_lengths, log_mels, log_mel_lengths, duration_priors, pitches, energies = batch
    ## Train models, etc.
    # Tokens represent already tokenized characters which probably will not work with previous tokenizers
    # You can tokenize via dataset.vocab.encode(), and go backwards with dataset.vocab.decode().
```

## NeMo Features

If you look into the code, you'll see that `TextMelAudioDataset` is a child of `nemo.core.classes.Dataset`. **You do not have to subclass this to add to the torch_tts repo**. It is sufficient to use `torch.utils.data.Dataset`. Using the nemo class adds some additional features not present in torch:

- *(Optional)* Adding typing information
  - Looking at `TextMelAudioDataset`, it has an `output_types` property that tells us how tensors are organized. For example, the `mels` returned by this dataset have dimensions B x D x T, which is short for saying the first dimension represents batch, the second represents a generic channels / n_mel_filters dimension, and the last represents time
- *(Optional)* *(ToDo)* Enables serialization
  - We can now call `to_config_dict()` to return a dictionary which we can then pass to `from_config_dict()` to create another instance of the dataset with the same arguments, allowing us to easily restate code using these dictionaries. Please note to change any local paths if changing computers.

## ToDos

- [ ] Populate *torch_tts*
  - [x] Create a new datalayer that can be used interchangeably
  - [x] Add phone support
  - [ ] Add TTS models
  - [ ] Add TTS models with new dataset
- [ ] Split Lightning away from core
  - [x] v0.1 that import checks a lot of lightning
  - [ ] Split up code (core, collections, utils) better
nemo/collections/tts/torch/tts_data_types.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class TTSDataType:
    name = None


class WithLens(TTSDataType):
    """Indicates that this TTSDataType also returns lengths for its data."""


class Audio(WithLens):
    name = "audio"


class Text(WithLens):
    name = "text"


class LogMel(WithLens):
    name = "log_mel"


class Durations(TTSDataType):
    name = "durations"


class DurationPrior(TTSDataType):
    name = "duration_prior"


class Pitch(WithLens):
    name = "pitch"


class Energy(WithLens):
    name = "energy"


MAIN_DATA_TYPES = [Audio, Text]
VALID_SUPPLEMENTARY_DATA_TYPES = [LogMel, Durations, DurationPrior, Pitch, Energy]
DATA_STR2DATA_CLASS = {d.name: d for d in MAIN_DATA_TYPES + VALID_SUPPLEMENTARY_DATA_TYPES}
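As a quick illustration (not part of the commit), this registry is what resolves the strings passed as `sup_data_types` to classes, and the `WithLens` check is what makes the collate function emit a matching `_lens` entry:

```python
from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, Pitch, WithLens

assert DATA_STR2DATA_CLASS["pitch"] is Pitch
assert issubclass(Pitch, WithLens)  # so batches also carry a "pitch_lens" tensor
```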
nemo/collections/tts/torch/tts_dataset.yaml (new file, 45 lines)
@@ -0,0 +1,45 @@
tts_dataset:
  _target_: "nemo.collections.tts.torch.data.TTSDataset"
  manifest_filepath: "manifest.json"
  sample_rate: 22050
  sup_data_path: "test_sup_data"
  sup_data_types: ["pitch"]
  n_fft: 1024
  win_length: 1024
  hop_length: 256
  window: "hann"
  n_mels: 80
  lowfreq: 0
  highfreq: 8000
  max_duration: null
  min_duration: null
  ignore_file: null
  trim: False
  pitch_fmin: 65.40639132514966
  pitch_fmax: 2093.004522404789

  text_normalizer:
    _target_: "nemo_text_processing.text_normalization.normalize.Normalizer"
    lang: "en"
    input_case: "cased"

  text_normalizer_call_args:
    verbose: False
    punct_pre_process: True
    punct_post_process: True

  text_tokenizer:
    _target_: "nemo.collections.tts.torch.tts_tokenizers.EnglishPhonemesTokenizer"
    punct: True
    stresses: True
    chars: True
    space: ' '
    silence: null
    apostrophe: True
    sep: '|'
    add_blank_at: null
    pad_with_space: True
    g2p:
      _target_: "nemo.collections.tts.torch.g2ps.EnglishG2p"
      phoneme_dict: "scripts/tts_dataset_files/cmudict-0.7b-030921"
      heteronyms: "scripts/tts_dataset_files/heteronyms-030921"
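A minimal sketch of consuming this config, mirroring the README example above (assumes NeMo, Hydra, and OmegaConf are installed and the paths in the yaml exist):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("nemo/collections/tts/torch/tts_dataset.yaml")
# Hydra instantiates the nested _target_ entries (text_tokenizer and its g2p) recursively.
dataset = instantiate(cfg.tts_dataset)
```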
nemo/collections/tts/torch/tts_tokenizers.py (new file, 310 lines)
@@ -0,0 +1,310 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import itertools
|
||||
import re
|
||||
import string
|
||||
import unicodedata
|
||||
from builtins import str as unicode
|
||||
from typing import List
|
||||
|
||||
# Example of parsing by groups via _words_re.
# Groups:
# 1st group -- valid English words,
# 2nd group -- any substring from | to | (must not be nested), useful when you want to leave a sequence unchanged,
# 3rd group -- punctuation marks.
# Text (first line) and mask of groups for every char (second line):
# config file must contain |EY1 EY1|, B, C, D, E, F, and G.
# 111111311113111131111111322222222233133133133133133111313
_words_re = re.compile("([a-zA-Z]+(?:[a-zA-Z-']*[a-zA-Z]+)*)|(\|[^|]*\|)|([^a-zA-Z|]+)")


def english_text_preprocessing(text):
    text = unicode(text)
    # NFD-normalize and drop combining marks (category Mn) to strip diacritics.
    text = ''.join(char for char in unicodedata.normalize('NFD', text) if unicodedata.category(char) != 'Mn')
    return text


def english_word_tokenize(text):
    """
    Convert text (str) to List[Tuple[Union[str, List[str]], bool]] where every tuple denotes a word representation
    and a flag indicating whether to leave the word unchanged.
    A word can be one of: a valid English word, any substring from | to | (an unchangeable word), or punctuation marks.
    This function expects that an unchangeable word is carefully divided by spaces (e.g. HH AH L OW).
    An unchangeable word will be split by spaces and represented as List[str]; other cases are represented as str.
    """
    words = _words_re.findall(text)
    result = []
    for word in words:
        maybe_word, maybe_without_changes, maybe_punct = word

        if maybe_word != '':
            without_changes = False
            result.append((maybe_word.lower(), without_changes))
        elif maybe_punct != '':
            without_changes = False
            result.append((re.sub(r'\s(\d)', r'\1', maybe_punct.upper()), without_changes))
        elif maybe_without_changes != '':
            without_changes = True
            result.append((maybe_without_changes[1:-1].split(" "), without_changes))
    return result


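For illustration, a hypothetical call (the expected output is inferred from the regex groups documented above, not taken from a test in this PR):

# |...| spans are kept verbatim as phoneme lists; everything else is lowercased or passed through.
print(english_word_tokenize("Hello, |HH AH L OW| world!"))
# [('hello', False), (', ', False), (['HH', 'AH', 'L', 'OW'], True),
#  (' ', False), ('world', False), ('!', False)]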
class BaseTokenizer(abc.ABC):
    PAD, BLANK, OOV = '<pad>', '<blank>', '<oov>'

    def __init__(self, tokens, *, pad=PAD, blank=BLANK, oov=OOV, sep='', add_blank_at=None):
        """Abstract class for creating an arbitrary tokenizer to convert string to a list of int tokens.
        Args:
            tokens: List of tokens.
            pad: Pad token as string.
            blank: Blank token as string.
            oov: OOV token as string.
            sep: Separation token as string.
            add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non-None value),
             if None then no blank in labels.
        """
        super().__init__()

        tokens = list(tokens)
        self.pad, tokens = len(tokens), tokens + [pad]  # Padding

        if add_blank_at is not None:
            self.blank, tokens = len(tokens), tokens + [blank]  # Reserved for blank from asr-model
        else:
            # use add_blank_at=None only for ASR where blank is added automatically, disable blank here
            self.blank = None

        self.oov, tokens = len(tokens), tokens + [oov]  # Out Of Vocabulary

        if add_blank_at == "last":
            tokens[-1], tokens[-2] = tokens[-2], tokens[-1]
            self.oov, self.blank = self.blank, self.oov

        self.tokens = tokens
        self.sep = sep

        self._util_ids = {self.pad, self.blank, self.oov}
        self._token2id = {l: i for i, l in enumerate(tokens)}
        self._id2token = tokens

    def __call__(self, text: str) -> List[int]:
        return self.encode(text)

    @abc.abstractmethod
    def encode(self, text: str) -> List[int]:
        """Turns str text into int tokens."""
        pass

    def decode(self, tokens: List[int]) -> str:
        """Turns int tokens into str text."""
        return self.sep.join(self._id2token[t] for t in tokens if t not in self._util_ids)


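To show what the base class buys a subclass, here is a minimal hypothetical tokenizer (not part of this PR): only `encode` needs writing, while the pad/blank/OOV bookkeeping and `decode` come for free:

class DigitsTokenizer(BaseTokenizer):
    """Hypothetical example: tokenizes digits only."""

    def __init__(self):
        super().__init__(list(string.digits))

    def encode(self, text):
        return [self._token2id[c] for c in text if c in self._token2id]

tok = DigitsTokenizer()
ids = tok.encode("call 911")
print(ids, tok.decode(ids))  # [9, 1, 1] 911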
class EnglishCharsTokenizer(BaseTokenizer):
    # fmt: off
    PUNCT_LIST = (  # Derived from LJSpeech and "/" additionally
        ',', '.', '!', '?', '-',
        ':', ';', '/', '"', '(',
        ')', '[', ']', '{', '}',
    )
    # fmt: on

    def __init__(
        self,
        punct=True,
        apostrophe=True,
        add_blank_at=None,
        pad_with_space=False,
        non_default_punct_list=None,
        text_preprocessing_func=english_text_preprocessing,
        word_tokenize_func=english_word_tokenize,
    ):
        """English char-based tokenizer.
        Args:
            punct: Whether to reserve grapheme for basic punctuation or not.
            apostrophe: Whether to use apostrophe or not.
            add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non-None value),
             if None then no blank in labels.
            pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
            non_default_punct_list: List of punctuation marks which will be used instead of the default.
            text_preprocessing_func: Function for preprocessing raw text.
            word_tokenize_func: Function for tokenizing text into words.
        """

        tokens = []
        self.space, tokens = len(tokens), tokens + [' ']  # Space
        tokens.extend(string.ascii_lowercase)
        if apostrophe:
            tokens.append("'")  # Apostrophe for saving "don't" and "Joe's"

        if punct:
            if non_default_punct_list is not None:
                self.PUNCT_LIST = non_default_punct_list
            tokens.extend(self.PUNCT_LIST)

        super().__init__(tokens, add_blank_at=add_blank_at)

        self.punct = punct
        self.pad_with_space = pad_with_space

        self.text_preprocessing_func = text_preprocessing_func
        self.word_tokenize_func = word_tokenize_func

    def encode(self, text):
        """See base class."""
        cs, space, tokens = [], self.tokens[self.space], set(self.tokens)

        words = [
            word[0] if isinstance(word, tuple) else word
            for word in self.word_tokenize_func(self.text_preprocessing_func(text))
        ]
        for c in "".join(words):  # noqa
            # Add a space if the last char isn't one
            if c == space and len(cs) > 0 and cs[-1] != space:
                cs.append(c)

            # Add next char
            if (c.isalnum() or c == "'") and c in tokens:
                cs.append(c)

            # Add punctuation
            if (c in self.PUNCT_LIST) and self.punct:
                cs.append(c)

        # Remove trailing spaces (guard against empty output)
        while cs and cs[-1] == space:
            cs.pop()

        if self.pad_with_space:
            cs = [space] + cs + [space]

        return [self._token2id[p] for p in cs]


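A hypothetical round trip with the char tokenizer (token ids depend on the vocabulary layout, so only the decoded string is shown):

tokenizer = EnglishCharsTokenizer(punct=True, apostrophe=True)
ids = tokenizer.encode("Don't stop!")
print(tokenizer.decode(ids))  # "don't stop!" -- lowercased by english_word_tokenize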
class EnglishPhonemesTokenizer(BaseTokenizer):
    # fmt: off
    PUNCT_LIST = (  # Derived from LJSpeech and "/" additionally
        ',', '.', '!', '?', '-',
        ':', ';', '/', '"', '(',
        ')', '[', ']', '{', '}',
    )
    VOWELS = (
        'AA', 'AE', 'AH', 'AO', 'AW',
        'AY', 'EH', 'ER', 'EY', 'IH',
        'IY', 'OW', 'OY', 'UH', 'UW',
    )
    CONSONANTS = (
        'B', 'CH', 'D', 'DH', 'F', 'G',
        'HH', 'JH', 'K', 'L', 'M', 'N',
        'NG', 'P', 'R', 'S', 'SH', 'T',
        'TH', 'V', 'W', 'Y', 'Z', 'ZH',
    )
    # fmt: on

    def __init__(
        self,
        g2p,
        punct=True,
        non_default_punct_list=None,
        stresses=False,
        chars=False,
        *,
        space=' ',
        silence=None,
        apostrophe=True,
        oov=BaseTokenizer.OOV,
        sep='|',  # To be able to distinguish between 2-/3-letter codes.
        add_blank_at=None,
        pad_with_space=False,
    ):
        """English phoneme-based tokenizer.
        Args:
            g2p: Grapheme-to-phoneme module.
            punct: Whether to reserve grapheme for basic punctuation or not.
            non_default_punct_list: List of punctuation marks which will be used instead of the default.
            stresses: Whether to use phoneme codes with stresses (0-2) or not.
            chars: Whether to additionally use chars together with phonemes. It is useful if the g2p module can return chars too.
            space: Space token as string.
            silence: Silence token as string (will be disabled if it is None).
            apostrophe: Whether to use apostrophe or not.
            oov: OOV token as string.
            sep: Separation token as string.
            add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non-None value),
             if None then no blank in labels.
            pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
        """

        tokens = []
        self.space, tokens = len(tokens), tokens + [space]  # Space

        if silence is not None:
            self.silence, tokens = len(tokens), tokens + [silence]  # Silence

        tokens.extend(self.CONSONANTS)
        vowels = list(self.VOWELS)

        if stresses:
            vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))]
        tokens.extend(vowels)

        if chars:
            tokens.extend(string.ascii_lowercase)

        if apostrophe:
            tokens.append("'")  # Apostrophe

        if punct:
            if non_default_punct_list is not None:
                self.PUNCT_LIST = non_default_punct_list
            tokens.extend(self.PUNCT_LIST)

        super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)

        self.chars = chars
        self.punct = punct
        self.stresses = stresses
        self.pad_with_space = pad_with_space

        self.g2p = g2p

    def encode(self, text):
        """See base class."""
        ps, space, tokens = [], self.tokens[self.space], set(self.tokens)

        for p in self.g2p(text):  # noqa
            # Remove stress if stresses are disabled
            if p.isalnum() and len(p) == 3 and not self.stresses:
                p = p[:2]

            # Add a space if the last phoneme isn't one
            if p == space and len(ps) > 0 and ps[-1] != space:
                ps.append(p)

            # Add next phoneme
            if (p.isalnum() or p == "'") and p in tokens:
                ps.append(p)

            # Add punctuation
            if (p in self.PUNCT_LIST) and self.punct:
                ps.append(p)

        # Remove trailing spaces (guard against empty output)
        while ps and ps[-1] == space:
            ps.pop()

        if self.pad_with_space:
            ps = [space] + ps + [space]

        return [self._token2id[p] for p in ps]
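And a hypothetical round trip with the phoneme tokenizer, mirroring the test configuration later in this PR (the exact phoneme string depends on the g2p output, so the decoded result is only illustrative):

from nemo.collections.tts.torch.g2ps import EnglishG2p

tokenizer = EnglishPhonemesTokenizer(
    g2p=EnglishG2p(), punct=True, stresses=True, chars=True, pad_with_space=True,
)
ids = tokenizer.encode("Hello world.")
print(tokenizer.decode(ids))  # e.g. " |HH|AH0|L|OW1| |W|ER1|L|D|.| "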
@ -3,5 +3,5 @@ pypinyin
attrdict
pystoi
pesq
g2p_en
pandas
inflect

134092 scripts/tts_dataset_files/cmudict-0.7b-030921 Normal file
File diff suppressed because it is too large
413 scripts/tts_dataset_files/heteronyms-030921 Normal file
@ -0,0 +1,413 @@
abject
abrogate
absent
abstract
abuse
ache
acre
acuminate
addict
address
adduct
adele
advocate
affect
affiliate
agape
aged
agglomerate
aggregate
agonic
agora
allied
ally
alternate
alum
am
analyses
andrea
animate
apply
appropriate
approximate
ares
arithmetic
arsenic
articulate
associate
attribute
august
axes
ay
aye
bases
bass
bathed
bested
bifurcate
blessed
blotto
bow
bowed
bowman
brassy
buffet
bustier
carbonate
celtic
choral
chumash
close
closer
coax
coincidence
color coordinate
colour coordinate
comber
combine
combs
committee
commune
compact
complex
compound
compress
concert
conduct
confine
confines
conflict
conglomerate
conscript
conserve
consist
console
consort
construct
consult
consummate
content
contest
contract
contracts
contrast
converse
convert
convict
coop
coordinate
covey
crooked
curate
cussed
decollate
decrease
defect
defense
delegate
deliberate
denier
desert
detail
deviate
diagnoses
diffuse
digest
discard
discharge
discount
do
document
does
dogged
domesticate
dominican
dove
dr
drawer
duplicate
egress
ejaculate
eject
elaborate
ellipses
email
emu
entrace
entrance
escort
estimate
eta
etna
evening
excise
excuse
exploit
export
extract
fine
flower
forbear
four-legged
frequent
furrier
gallant
gel
geminate
gillie
glower
gotham
graduate
haggis
heavy
hinder
house
housewife
impact
imped
implant
implement
import
impress
incense
incline
increase
infix
insert
instar
insult
integral
intercept
interchange
interflow
interleaf
intermediate
intern
interspace
intimate
intrigue
invalid
invert
invite
irony
jagged
jesses
julies
kite
laminate
laos
lather
lead
learned
leasing
lech
legitimate
lied
lima
lipread
live
lower
lunged
maas
magdalen
manes
mare
marked
merchandise
merlion
minute
misconduct
misled
misprint
mobile
moderate
mong
moped
moth
mouth
mow
mpg
multiply
mush
nana
nice
nice
number
numerate
nun
object
opiate
ornament
outbox
outcry
outpour
outreach
outride
outright
outside
outwork
overall
overbid
overcall
overcast
overfall
overflow
overhaul
overhead
overlap
overlay
overuse
overweight
overwork
pace
palled
palling
para
pasty
pate
pauline
pedal
peer
perfect
periodic
permit
pervert
pinta
placer
platy
polish
polish
poll
pontificate
postulate
pram
prayer
precipitate
predate
predicate
prefix
preposition
present
pretest
primer
proceeds
produce
progress
project
proportionate
prospect
protest
pussy
putter
putting
quite
ragged
raven
re
read
reading
reading
real
rebel
recall
recap
recitative
recollect
record
recreate
recreation
redress
refill
refund
refuse
reject
relay
remake
repaint
reprint
reread
rerun
resent
reside
resign
respray
resume
retard
retest
retread
rewrite
root
routed
routing
row
rugged
rummy
sais
sake
sambuca
saucier
second
secrete
secreted
secreting
segment
separate
sewer
shirk
shower
sin
skied
slaver
slough
sow
spoof
squid
stingy
subject
subordinate
subvert
supply
supposed
survey
suspect
syringes
tabulate
tales
tarrier
tarry
taxes
taxis
tear
theron
thou
three-legged
tier
tinged
torment
transfer
transform
transplant
transport
transpose
tush
two-legged
unionised
unionized
update
uplift
upset
use
used
vale
violist
viva
ware
whinged
whoop
wicked
wind
windy
wino
won
worsted
wound
@ -17,10 +17,12 @@ import os
import pytest
import torch

from nemo.collections.tts.torch.data import CharMelAudioDataset
from nemo.collections.tts.torch.data import TTSDataset
from nemo.collections.tts.torch.g2ps import EnglishG2p
from nemo.collections.tts.torch.tts_tokenizers import EnglishPhonemesTokenizer


class TestCharDataset:
class TestTTSDataset:
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.torch_tts
@ -28,27 +30,21 @@ class TestCharDataset:
        manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
        sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup')

        dataset = CharMelAudioDataset(
            manifest_filepath=manifest_path, sample_rate=22050, supplementary_folder=sup_path
        dataset = TTSDataset(
            manifest_filepath=manifest_path,
            sample_rate=22050,
            sup_data_types=["pitch"],
            sup_data_path=sup_path,
            text_tokenizer=EnglishPhonemesTokenizer(
                punct=True,
                stresses=True,
                chars=True,
                space=' ',
                apostrophe=True,
                pad_with_space=True,
                g2p=EnglishG2p(),
            ),
        )

        dataloader = torch.utils.data.DataLoader(dataset, 2, collate_fn=dataset._collate_fn)

        data, _, _, _, _, _, _ = next(iter(dataloader))


class TestPhoneDataset:
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    @pytest.mark.torch_tts
    def test_dataset(self, test_data_dir):
        manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
        sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup')

        dataset = CharMelAudioDataset(
            manifest_filepath=manifest_path, sample_rate=22050, supplementary_folder=sup_path
        )

        dataloader = torch.utils.data.DataLoader(dataset, 2, collate_fn=dataset._collate_fn)

        _, _, _, _, _, _, _ = next(iter(dataloader))
        data, _, _, _, _, _ = next(iter(dataloader))