DeepLearningExamples/CUDA-Optimized/FastSpeech/generate.py

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import pathlib
import sys
import time

import fire
import librosa
import torch

from fastspeech.data_load import PadDataLoader
from fastspeech.dataset.text_dataset import TextDataset
from fastspeech.inferencer.fastspeech_inferencer import FastSpeechInferencer
from fastspeech.model.fastspeech import Fastspeech
from fastspeech import hparam as hp, DEFAULT_DEVICE
from fastspeech.utils.logging import tprint
from fastspeech.utils.time import TimeElapsed
from fastspeech.utils.pytorch import to_device_async, to_cpu_numpy
from fastspeech.infer import get_inferencer
from fastspeech.inferencer.waveglow_inferencer import WaveGlowInferencer

# Maximum number of characters of the input text used as the output filename.
MAX_FILESIZE = 128


# TODO test with different speeds
def generate(hparam='infer.yaml',
             text='test_sentences.txt',
             results_path='results',
             device=DEFAULT_DEVICE,
             **kwargs):
"""The script for generating waveforms from texts with a vocoder.
By default, this script assumes to load parameters in the default config file, fastspeech/hparams/infer.yaml.
Besides the flags, you can also set parameters in the config file via the command-line. For examples,
--checkpoint_path=CHECKPOINT_PATH
Path to checkpoint directory. The latest checkpoint will be loaded.
--waveglow_path=WAVEGLOW_PATH
Path to the WaveGlow checkpoint file.
--waveglow_engine_path=WAVEGLOW_ENGINE_PATH
Path to the WaveGlow engine file. It can be only used with --use_trt=True.
--batch_size=BATCH_SIZE
Batch size to use. Defaults to 1.
Refer to fastspeech/hparams/infer.yaml to see more parameters.
Args:
hparam (str, optional): Path to default config file. Defaults to "infer.yaml".
text (str, optional): a sample text or a text file path to generate its waveform. Defaults to 'test_sentences.txt'.
results_path (str, optional): Path to output waveforms directory. Defaults to 'results'.
device (str, optional): Device to use. Defaults to "cuda" if avaiable, or "cpu".
"""
    hp.set_hparam(hparam, kwargs)

    if os.path.isfile(text):
        with open(text, 'r', encoding="utf-8") as f:
            texts = f.read().splitlines()
    else:  # single string
        texts = [text]

    dataset = TextDataset(texts)
    data_loader = PadDataLoader(dataset,
                                batch_size=hp.batch_size,
                                num_workers=hp.n_workers,
                                shuffle=False,
                                drop_last=False)
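    # Note: PadDataLoader, as the name suggests, is assumed to pad each batch of
    # variable-length text sequences to the batch's longest sequence so that the
    # collated tensors have a uniform shape.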

    # text to mel
    model = Fastspeech(
        max_seq_len=hp.max_seq_len,
        d_model=hp.d_model,
        phoneme_side_n_layer=hp.phoneme_side_n_layer,
        phoneme_side_head=hp.phoneme_side_head,
        phoneme_side_conv1d_filter_size=hp.phoneme_side_conv1d_filter_size,
        phoneme_side_output_size=hp.phoneme_side_output_size,
        mel_side_n_layer=hp.mel_side_n_layer,
        mel_side_head=hp.mel_side_head,
        mel_side_conv1d_filter_size=hp.mel_side_conv1d_filter_size,
        mel_side_output_size=hp.mel_side_output_size,
        duration_predictor_filter_size=hp.duration_predictor_filter_size,
        duration_predictor_kernel_size=hp.duration_predictor_kernel_size,
        fft_conv1d_kernel=hp.fft_conv1d_kernel,
        fft_conv1d_padding=hp.fft_conv1d_padding,
        dropout=hp.dropout,
        n_mels=hp.num_mels,
        fused_layernorm=hp.fused_layernorm
    )
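
    # get_inferencer (from fastspeech/infer.py) is assumed to wrap the model and
    # data loader into a FastSpeechInferencer and restore the latest checkpoint
    # found under hp.checkpoint_path.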
    fs_inferencer = get_inferencer(model, data_loader, device)

    # set up WaveGlow
    if hp.use_trt:
        from fastspeech.trt.waveglow_trt_inferencer import WaveGlowTRTInferencer
        wb_inferencer = WaveGlowTRTInferencer(
            ckpt_file=hp.waveglow_path, engine_file=hp.waveglow_engine_path, use_fp16=hp.use_fp16)
    else:
        wb_inferencer = WaveGlowInferencer(
            ckpt_file=hp.waveglow_path, device=device, use_fp16=hp.use_fp16)

tprint("Generating {} sentences.. ".format(len(dataset)))
with fs_inferencer, wb_inferencer:
try:
for i in range(len(data_loader)):
tprint("------------- BATCH # {} -------------".format(i))
with TimeElapsed(name="Inferece Time: E2E", format=":.6f"):
## Text-to-Mel ##
with TimeElapsed(name="Inferece Time: FastSpeech", device=device, cuda_sync=True, format=":.6f"), torch.no_grad():
outputs = fs_inferencer.infer()
texts = outputs["text"]
mels = outputs["mel"] # (b, n_mels, t)
mel_masks = outputs['mel_mask'] # (b, t)
# assert(mels.is_cuda)
# remove paddings
mel_lens = mel_masks.sum(axis=1)
max_len = mel_lens.max()
mels = mels[..., :max_len]
mel_masks = mel_masks[..., :max_len]

                    ## Vocoder ##
                    with TimeElapsed(name="Inference Time: WaveGlow", device=device, cuda_sync=True, format=":.6f"), torch.no_grad():
                        wavs = wb_inferencer.infer(mels)
                        wavs = to_cpu_numpy(wavs)

                ## Write wavs ##
                pathlib.Path(results_path).mkdir(parents=True, exist_ok=True)
                for j, (text, wav) in enumerate(zip(texts, wavs)):
                    tprint("TEXT #{}: \"{}\"".format(j, text))

                    # remove paddings in case of batch size > 1:
                    # each mel frame corresponds to hop_len audio samples
                    wav_len = mel_lens[j] * hp.hop_len
                    wav = wav[:wav_len]

                    # name the output file after the text, truncated to MAX_FILESIZE chars
                    path = os.path.join(results_path, text[:MAX_FILESIZE] + ".wav")
                    librosa.output.write_wav(path, wav, hp.sr)
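                    # NOTE: librosa.output.write_wav exists only in librosa < 0.8;
                    # it was removed in librosa 0.8. On newer versions,
                    # soundfile.write(path, wav, hp.sr) is the usual replacement.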

        except StopIteration:
            tprint("Generation is done.")

        except KeyboardInterrupt:
            tprint("Generation has been canceled.")


if __name__ == '__main__':
    fire.Fire(generate)
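
# The same entry point can also be called programmatically, since fire simply
# forwards CLI flags as function arguments, e.g. (hypothetical values):
#   from generate import generate
#   generate(text="Hello world.", batch_size=1)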