[FastPitch/PyT] Drop parselmouth dependency

This commit is contained in:
Adrian Lancucki 2021-10-07 11:26:31 -07:00 committed by Krzysztof Kudrynski
parent 61bcab7a0e
commit 054fed043f
7 changed files with 10 additions and 47 deletions

View file

@ -363,16 +363,11 @@ FastPitch 1.1 aligns input symbols to output mel-spectrogram frames automaticall
on any external aligning model. FastPitch training can now be started on raw waveforms
without any pre-processing: pitch values and mel-spectrograms will be calculated on-line.
For every mel-spectrogram frame, its fundamental frequency in Hz is estimated with either
the Probabilistic YIN algorithm or [Praat](http://praat.org).
The former is more accurate but time consuming, and we recommend to pre-calculate
pitch during the data processing step. The latter is suitable for on-line pitch calculation.
Pitch values are then averaged over every character, in order to provide sparse
pitch cues for the model.
For every mel-spectrogram frame, its fundamental frequency in Hz is estimated with
the Probabilistic YIN algorithm.
<p align="center">
<img src="./img/pitch.png" alt="Pitch estimates extracted with Praat" />
<img src="./img/pitch.png" alt="Pitch contour estimate" />
</p>
<p align="center">
<em>Figure 2. Pitch estimates for mel-spectrogram frames of phrase "in being comparatively"

View file

@ -32,7 +32,6 @@ from pathlib import Path
import librosa
import numpy as np
import parselmouth
import torch
import torch.nn.functional as F
from scipy import ndimage
@ -88,35 +87,7 @@ def estimate_pitch(wav, mel_len, method='pyin', normalize_mean=None,
if type(normalize_std) is float or type(normalize_std) is list:
normalize_std = torch.tensor(normalize_std)
if method == 'praat':
snd = parselmouth.Sound(wav)
pitch_mel = snd.to_pitch(time_step=snd.duration / (mel_len + 3)
).selected_array['frequency']
assert np.abs(mel_len - pitch_mel.shape[0]) <= 1.0
pitch_mel = torch.from_numpy(pitch_mel).unsqueeze(0)
if n_formants > 1:
formant = snd.to_formant_burg(
time_step=snd.duration / (mel_len + 3))
formant_n_frames = formant.get_number_of_frames()
assert np.abs(mel_len - formant_n_frames) <= 1.0
formants_mel = np.zeros((formant_n_frames + 1, n_formants - 1))
for i in range(1, formant_n_frames + 1):
formants_mel[i] = np.asarray([
formant.get_value_at_time(
formant_number=f,
time=formant.get_time_from_frame_number(i))
for f in range(1, n_formants)
])
pitch_mel = torch.cat(
[pitch_mel, torch.from_numpy(formants_mel).permute(1, 0)],
dim=0)
elif method == 'pyin':
if method == 'pyin':
snd, sr = librosa.load(wav)
pitch_mel, voiced_flag, voiced_probs = librosa.pyin(
@ -181,7 +152,7 @@ class TTSDataset(torch.utils.data.Dataset):
pitch_online_dir=None,
betabinomial_online_dir=None,
use_betabinomial_interpolator=True,
pitch_online_method='praat',
pitch_online_method='pyin',
**ignored):
# Expect a list of filenames
@ -338,7 +309,7 @@ class TTSDataset(torch.utils.data.Dataset):
if cached_fpath.is_file():
return torch.load(cached_fpath)
# No luck so far - calculate or replace with praat
# No luck so far - calculate
wav = audiopath
if not wav.endswith('.wav'):
wav = re.sub('/mels/', '/wavs/', wav)

View file

@ -73,7 +73,7 @@ def parse_args(parser):
parser.add_argument('--n-mel-channels', type=int, default=80)
# Pitch extraction
parser.add_argument('--f0-method', default='pyin', type=str,
choices=('pyin', 'praat'), help='F0 estimation method')
choices=['pyin'], help='F0 estimation method')
# Performance
parser.add_argument('-b', '--batch-size', default=1, type=int)
parser.add_argument('--n-workers', type=int, default=16)

View file

@ -4,6 +4,5 @@ inflect
librosa==0.8.0
scipy
Unidecode
praat-parselmouth==0.3.3
tensorboardX==2.0
git+git://github.com/NVIDIA/dllogger.git@26a0f8f1958de2c0c460925ff6102a4d2486d6cc#egg=dllogger

View file

@ -3,7 +3,6 @@
set -e
: ${DATA_DIR:=LJSpeech-1.1}
: ${F0_METHOD:="pyin"}
: ${ARGS="--extract-mels"}
python prepare_dataset.py \
@ -12,5 +11,5 @@ python prepare_dataset.py \
--batch-size 1 \
--dataset-path $DATA_DIR \
--extract-pitch \
--f0-method $F0_METHOD \
--f0-method pyin \
$ARGS

View file

@ -6,7 +6,6 @@ set -a
: ${NUM_GPUS_SEQUENCE:="1 4 8"}
: ${EPOCHS:=30}
: ${OUTPUT_DIR:="./output"}
: ${F0_METHOD:=praat}
: ${BATCH_SIZE:=16}
for NUM_GPUS in $NUM_GPUS_SEQUENCE ; do

View file

@ -147,8 +147,8 @@ def parse_args(parser):
'n_speakers > 1 enables speaker embeddings')
cond.add_argument('--load-pitch-from-disk', action='store_true',
help='Use pitch cached on disk with prepare_dataset.py')
cond.add_argument('--pitch-online-method', default='praat',
choices=['praat', 'pyin'],
cond.add_argument('--pitch-online-method', default='pyin',
choices=['pyin'],
help='Calculate pitch on the fly during trainig')
cond.add_argument('--pitch-online-dir', type=str, default=None,
help='A directory for storing pitch calculated on-line')