DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py
2021-10-07 11:26:31 -07:00

175 lines
7.4 KiB
Python

# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import argparse
import time
from pathlib import Path
import torch
import tqdm
import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
from torch.utils.data import DataLoader
from fastpitch.data_function import TTSCollate, TTSDataset
def parse_args(parser):
    """Register the data pre-processing command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` that is extended in place.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    parser.add_argument('-d', '--dataset-path', type=str,
                        default='./', help='Path to dataset')
    parser.add_argument('--wav-text-filelists', required=True, nargs='+',
                        type=str, help='Files with audio paths and text')
    parser.add_argument('--extract-mels', action='store_true',
                        help='Calculate spectrograms from .wav files')
    parser.add_argument('--extract-pitch', action='store_true',
                        help='Extract pitch')
    parser.add_argument('--save-alignment-priors', action='store_true',
                        help='Pre-calculate diagonal matrices of alignment of text to audio')
    parser.add_argument('--log-file', type=str, default='preproc_log.json',
                        help='Filename for logging')
    parser.add_argument('--n-speakers', type=int, default=1,
                        help='Number of speakers')

    # Mel extraction
    parser.add_argument('--max-wav-value', default=32768.0, type=float,
                        help='Maximum audiowave value')
    parser.add_argument('--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--filter-length', default=1024, type=int,
                        help='Filter length')
    parser.add_argument('--hop-length', default=256, type=int,
                        help='Hop (stride) length')
    parser.add_argument('--win-length', default=1024, type=int,
                        help='Window length')
    parser.add_argument('--mel-fmin', default=0.0, type=float,
                        help='Minimum mel frequency')
    parser.add_argument('--mel-fmax', default=8000.0, type=float,
                        help='Maximum mel frequency')
    parser.add_argument('--n-mel-channels', type=int, default=80,
                        help='Number of mel-spectrogram channels')

    # Pitch extraction
    parser.add_argument('--f0-method', default='pyin', type=str,
                        choices=['pyin'], help='F0 estimation method')

    # Performance
    parser.add_argument('-b', '--batch-size', default=1, type=int,
                        help='Batch size of the data loader')
    parser.add_argument('--n-workers', type=int, default=16,
                        help='Number of data loader worker processes')
    return parser
def main():
    """Pre-compute and save per-utterance training targets for FastPitch.

    For every filelist in ``--wav-text-filelists`` a ``TTSDataset`` is built
    and iterated through a ``DataLoader``. Depending on the flags, each
    utterance's mel spectrogram, pitch and alignment prior are trimmed to
    their true lengths and saved as
    ``<dataset-path>/{mels,pitch,alignment_priors}/<stem>.pt``.

    Raises:
        ValueError: on unrecognized command-line options, or when two entries
            of the same filelist would be written to the same output file.
    """
    parser = argparse.ArgumentParser(description='FastPitch Data Pre-processing')
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()
    if unk_args:
        raise ValueError(f'Invalid options {unk_args}')

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT,
                          Path(args.dataset_path, args.log_file)),
        StdOutBackend(Verbosity.VERBOSE)])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.flush()

    mels_dir = Path(args.dataset_path, 'mels')
    pitch_dir = Path(args.dataset_path, 'pitch')
    priors_dir = Path(args.dataset_path, 'alignment_priors')
    if args.extract_mels:
        mels_dir.mkdir(parents=False, exist_ok=True)
    if args.extract_pitch:
        pitch_dir.mkdir(parents=False, exist_ok=True)
    if args.save_alignment_priors:
        priors_dir.mkdir(parents=False, exist_ok=True)

    for filelist in args.wav_text_filelists:
        print(f'Processing {filelist}...')

        dataset = TTSDataset(
            args.dataset_path,
            filelist,
            text_cleaners=['english_cleaners_v2'],
            n_mel_channels=args.n_mel_channels,
            p_arpabet=0.0,
            n_speakers=args.n_speakers,
            load_mel_from_disk=False,
            load_pitch_from_disk=False,
            pitch_mean=None,
            pitch_std=None,
            max_wav_value=args.max_wav_value,
            sampling_rate=args.sampling_rate,
            filter_length=args.filter_length,
            hop_length=args.hop_length,
            win_length=args.win_length,
            mel_fmin=args.mel_fmin,
            mel_fmax=args.mel_fmax,
            betabinomial_online_dir=None,
            pitch_online_dir=None,
            pitch_online_method=args.f0_method)

        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=None,
            num_workers=args.n_workers,
            collate_fn=TTSCollate(),
            pin_memory=False,
            drop_last=False)

        # Output filenames already produced for this filelist.
        all_filenames = set()
        for batch in tqdm.tqdm(data_loader):
            _, input_lens, mels, mel_lens, _, pitch, _, _, attn_prior, fpaths = batch

            # Check uniqueness on the name actually written to disk: entries
            # like a/x.wav and b/x.flac both map to x.pt and would otherwise
            # silently overwrite each other.
            out_names = [Path(p).with_suffix('.pt').name for p in fpaths]
            for fname, src in zip(out_names, fpaths):
                if fname in all_filenames:
                    raise ValueError(f'Filename is not unique: {fname} (from {src})')
                all_filenames.add(fname)

            # The per-item tensors below are views into the collated batch
            # tensors. torch.save serializes a view together with its whole
            # underlying storage, so clone() is needed to write only the
            # trimmed data instead of the entire batch into every file.
            for j, fname in enumerate(out_names):
                n_frames = int(mel_lens[j])
                if args.extract_mels:
                    torch.save(mels[j][:, :n_frames].clone(), mels_dir / fname)
                if args.extract_pitch:
                    # NOTE(review): assumes time is the last pitch axis, as for
                    # the mels (confirm against TTSCollate). `...` makes this
                    # identical to `[:n_frames]` for 1-D pitch and still trims
                    # the padding if a leading channel axis is present.
                    torch.save(pitch[j][..., :n_frames].clone(), pitch_dir / fname)
                if args.save_alignment_priors:
                    n_symbols = int(input_lens[j])
                    torch.save(attn_prior[j][:n_frames, :n_symbols].clone(),
                               priors_dir / fname)
# Run the pre-processing only when executed as a script, never on import.
if __name__ == "__main__":
    main()