[Tacotron2] Added denoiser and inference stats, fixed typos

Przemek Strzelczyk 2019-09-10 16:22:53 +02:00
parent da8acb1288
commit 02b49acead
9 changed files with 540 additions and 61 deletions

View file

@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:19.07-py3
FROM nvcr.io/nvidia/pytorch:19.08-py3
ADD . /workspace/tacotron2
WORKDIR /workspace/tacotron2

View file

@@ -1,4 +1,4 @@
# Tacotron 2 And WaveGlow v1.6 For PyTorch
# Tacotron 2 And WaveGlow v1.7 For PyTorch
This repository provides a script and recipe to train Tacotron 2 and WaveGlow
v1.6 models to achieve state-of-the-art accuracy, and is tested and maintained by NVIDIA.
@@ -38,7 +38,8 @@ v1.6 models to achieve state of the art accuracy, and is tested and maintained b
* [NVIDIA DGX-1 (8x V100 16G)](#nvidia-dgx-1-8x-v100-16g)
* [Expected training time](#expected-training-time)
* [Inference performance results](#inference-performance-results)
* [NVIDIA DGX-1 (8x V100 16G)](#nvidia-dgx-1-8x-v100-16g)
* [NVIDIA V100 16G](#nvidia-v100-16g)
* [NVIDIA T4](#nvidia-t4)
* [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
@@ -99,7 +100,7 @@ into spherical Gaussian distribution through a series of flows. One step of a
flow consists of an invertible convolution, followed by a modified WaveNet
architecture that serves as an affine coupling layer. During inference, the
network is inverted and audio samples are generated from the Gaussian
distribution.
distribution. Our implementation uses 512 residual channels in the coupling layer.
![](./img/waveglow_arch.png "WaveGlow architecture")
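As a rough sketch of this inversion (assuming `waveglow` is a loaded, eval-mode WaveGlow model; the `infer` call matches its usage in `inference.py` below, and the random tensor merely stands in for a real Tacotron 2 mel-spectrogram):
```python
import torch

# stand-in for a Tacotron 2 output: (batch, n_mel_channels=80, frames)
mel = torch.randn(1, 80, 600).cuda()

with torch.no_grad():
    # sigma scales the Gaussian noise that the inverted flows map to audio
    audio = waveglow.infer(mel, sigma=0.6)  # waveglow: assumed already loaded
```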
@@ -130,16 +131,16 @@ The following features are supported by this model.
|[AMP](https://nvidia.github.io/apex/amp.html) | Yes | Yes |
|[Apex DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) | Yes | Yes |
#### Features
AMP - a tool that enables Tensor Core-accelerated training. For more information,
refer to [Enabling mixed precision](#enabling-mixed-precision).
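For reference, a minimal sketch of the Apex AMP API (a generic illustration, not this repository's exact training code; the scripts here enable AMP via the `--amp-run` flag):
```python
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# opt_level O2 runs most ops in FP16 on Tensor Cores while keeping FP32 master weights
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # loss scaling avoids FP16 gradient underflow
optimizer.step()
```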
Apex DistributedDataParallel - a module wrapper that enables easy multiprocess
distributed data parallel training, similar to `torch.nn.parallel.DistributedDataParallel`.
`DistributedDataParallel` is optimized for use with NCCL. It achieves high
performance by overlapping communication with computation during `backward()`
and bucketing smaller gradient transfers to reduce the total number of transfers
required.
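A minimal sketch of Apex `DistributedDataParallel` usage, assuming one process per GPU on a single node (a generic illustration, not this repository's exact wiring):
```python
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# assumes the env:// rendezvous variables are set, e.g. by torch.distributed.launch
dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(10, 10).cuda()
model = DDP(model)  # all-reduces gradients in buckets, overlapping with backward()
```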
## Mixed precision training
@@ -267,16 +268,9 @@ this script, issue:
bash scripts/prepare_dataset.sh
```
To preprocess the datasets for Tacotron 2 training, use the
`./scripts/prepare_mels.sh` script:
```bash
bash scripts/prepare_mels.sh
```
Data is downloaded to the `./LJSpeech-1.1` directory (on the host). The
`./LJSpeech-1.1` directory is mounted to the `/workspace/tacotron2/LJSpeech-1.1`
location in the NGC container. The preprocessed mel-spectrograms are stored in the
`./LJSpeech-1.1/mels` directory.
3. Build the Tacotron 2 and WaveGlow PyTorch NGC container.
```bash
@@ -290,8 +284,14 @@ After you build the container image, you can start an interactive CLI session wi
bash scripts/docker/interactive.sh
```
The `interactive.sh` script requires that the location of the dataset is specified.
For example, `LJSpeech-1.1`. To preprocess the datasets for Tacotron 2 training, use
the `./scripts/prepare_mels.sh` script:
```bash
bash scripts/prepare_mels.sh
```
The preprocessed mel-spectrograms are stored in the `./LJSpeech-1.1/mels` directory.
5. Start training.
To start Tacotron 2 training, run:
@@ -313,8 +313,8 @@ Ensure your loss values are comparable to those listed in the table in the
samples in the `./audio` folder. For details about generating audio, see the
[Inference process](#inference-process) section below.
The training scripts automatically run the validation after each training
epoch. The results from the validation are printed to the standard output
(`stdout`) and saved to the log files.
7. Start inference.
@@ -327,10 +327,10 @@ and `--waveglow` arguments.
```bash
python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> -o output/ -i phrases/phrase.txt --amp-run
```
The speech is generated from lines of text in the file that is passed with the
`-i` argument. The number of lines determines the inference batch size. To run
inference in mixed precision, use the `--amp-run` flag. The output audio will
be stored in the path specified by the `-o` argument.
## Advanced
@@ -390,11 +390,12 @@ WaveGlow models.
#### WaveGlow parameters
* `--segment-length` - segment length of input audio processed by the neural network (8000)
* `--wn-channels` - number of residual channels in the coupling layer networks (512)
### Command-line options
To see the full list of available options and their descriptions, use the `-h`
or `--help` command line option, for example:
```bash
python train.py --help
@@ -470,8 +471,12 @@ To run inference, issue:
```bash
python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> -o output/ --include-warmup -i phrases/phrase.txt --amp-run
```
Here, `Tacotron2_checkpoint` and `WaveGlow_checkpoint` are pre-trained
checkpoints for the respective models, and `phrases/phrase.txt` contains input
phrases. The number of text lines determines the inference batch size. Audio
will be saved in the output folder. The audio files [audio_fp16](./audio/audio_fp16.wav)
and [audio_fp32](./audio/audio_fp32.wav) were generated using checkpoints from
mixed precision and FP32 training, respectively.
You can find all the available options by calling `python inference.py --help`.
@@ -548,9 +553,9 @@ To benchmark the inference performance on a batch size=1, run:
```
The output log files will contain performance numbers for Tacotron 2 model
(number of output mel-spectrograms per second, reported as `tacotron2_items_per_sec`)
and for WaveGlow (number of output samples per second, reported as `waveglow_items_per_sec`).
The `inference.py` script will run a few warmup iterations before running the benchmark.
### Results
@@ -635,31 +640,36 @@ The following table shows the expected training time for convergence for WaveGlo
#### Inference performance results
##### NVIDIA DGX-1 (8x V100 16G)
The following tables show inference statistics for the Tacotron 2 and WaveGlow
text-to-speech system, gathered from 1000 inference runs, on one V100 and one T4
GPU, respectively. Latency is measured from the start of Tacotron 2 inference to
the end of WaveGlow inference. The tables include average latency, latency standard
deviation, and latency confidence intervals. Throughput is measured as the number
of generated audio samples per second. RTF is the real-time factor: the number of
seconds of speech generated per second of compute.
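As a worked example, take the first row of the V100 table below: at batch size 1
in FP16, the system produces 6.98 s of audio in 1.73 s of compute, so
RTF ≈ 6.98 / 1.73 ≈ 4.03, matching the reported 4.04 up to rounding.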
Our results were obtained by running the `./inference.py` inference script in
the PyTorch-19.06-py3 NGC container on NVIDIA DGX-1 with 8x V100 16G GPUs.
Performance numbers (in output mel-spectrograms per second for Tacotron 2 and
output samples per second for WaveGlow) were averaged over 16 runs.
##### NVIDIA V100 16G
The following table shows the inference performance results for Tacotron 2 model.
Results are measured in the number of output mel-spectrograms per second.
|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 50% (s)|Latency confidence interval 100% (s)|Throughput (samples/sec)|Speed-up with mixed precision|Avg mels generated (81 mels=1 sec of speech)|Avg audio length (s)|Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|1| 128| FP16| 1.73| 0.07| 1.72| 2.11| 89,162| 1.09| 601| 6.98| 4.04|
|4| 128| FP16| 4.21| 0.17| 4.19| 4.84| 145,800| 1.16| 600| 6.97| 1.65|
|1| 128| FP32| 1.85| 0.06| 1.84| 2.19| 81,868| 1.00| 590| 6.85| 3.71|
|4| 128| FP32| 4.80| 0.15| 4.79| 5.43| 125,930| 1.00| 590| 6.85| 1.43|
|Number of GPUs|Number of mels used with mixed precision|Number of mels used with FP32|Speed-up with mixed precision|
|---:|---:|---:|---:|
|**1**|625|613|1.02|
##### NVIDIA T4
The following table shows the inference performance results for WaveGlow model.
Results are measured in the number of output samples per second<sup>1</sup>.
|Number of GPUs|Number of samples used with mixed precision|Number of samples used with FP32|Speed-up with mixed precision|
|---:|---:|---:|---:|
|**1**|180474|162282|1.11|
<sup>1</sup>With sampling rate equal to 22050, one second of audio is generated from 22050 samples.
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 50% (s)|Latency confidence interval 100% (s)|Throughput (samples/sec)|Speed-up with mixed precision|Avg mels generated (81 mels=1 sec of speech)|Avg audio length (s)|Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|1| 128| FP16| 3.16| 0.13| 3.16| 3.81| 48,792| 1.23| 603| 7.00| 2.21|
|4| 128| FP16| 11.45| 0.49| 11.39| 14.38| 53,771| 1.22| 601| 6.98| 0.61|
|1| 128| FP32| 3.82| 0.11| 3.81| 4.24| 39,603| 1.00| 591| 6.86| 1.80|
|4| 128| FP32| 13.80| 0.45| 13.74| 16.09| 43,915| 1.00| 592| 6.87| 0.50|
Our results were obtained by running the `./run_latency_tests.sh` script in
the PyTorch-19.06-py3 NGC container. Note that to reproduce these results, you
need to provide pretrained checkpoints for Tacotron 2 and WaveGlow, and edit
the script to point to your checkpoint filenames.
## Release notes
@@ -674,7 +684,7 @@ June 2019
* Fixed dropouts on LSTMCells
July 2019
* Changed measurement units for Tacotron 2 training and inference performance
benchmarks from input tokens per second to output mel-spectrograms per second
* Introduced batched inference
* Included warmup in the inference script
@@ -683,6 +693,10 @@ August 2019
* Fixed inference results
* Fixed initialization of Batch Normalization
September 2019
* Introduced inference statistics
### Known issues
There are no known issues in this release.

View file

@@ -124,6 +124,7 @@ class STFT(torch.nn.Module):
np.where(window_sum > tiny(window_sum))[0])
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False)
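# keep window_sum on the same device as the spectrogram before the in-place division below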
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# scale by hop ratio

View file

@@ -41,6 +41,8 @@ from dllogger.autologging import log_hardware, log_args
from apex import amp
from waveglow.denoiser import Denoiser
def parse_args(parser):
"""
Parse commandline arguments.
@@ -53,7 +55,8 @@ def parse_args(parser):
help='full path to the Tacotron2 model checkpoint file')
parser.add_argument('--waveglow', type=str,
help='full path to the WaveGlow model checkpoint file')
parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
parser.add_argument('-s', '--sigma-infer', default=0.9, type=float)
parser.add_argument('-d', '--denoising-strength', default=0.01, type=float)
parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
help='Sampling rate')
parser.add_argument('--amp-run', action='store_true',
@@ -212,6 +215,7 @@
args.amp_run)
waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
args.amp_run)
denoiser = Denoiser(waveglow).cuda()
texts = []
try:
@@ -242,6 +246,7 @@
with torch.no_grad(), MeasureTime(measurements, "waveglow_time"):
audios = waveglow.infer(mel, sigma=args.sigma_infer)
audios = audios.float()
audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
tacotron2_infer_perf = mel.size(0)*mel.size(2)/measurements['tacotron2_time']
waveglow_infer_perf = audios.size(0)*audios.size(1)/measurements['waveglow_time']
@@ -254,9 +259,10 @@
measurements['waveglow_time']))
for i, audio in enumerate(audios):
audio = audio[:mel_lengths[i]*args.stft_hop_length]
audio = audio/torch.max(torch.abs(audio))
audio_path = args.output + "audio_"+str(i)+".wav"
write(audio_path, args.sampling_rate,
audio.data.cpu().numpy()[:mel_lengths[i]*args.stft_hop_length])
write(audio_path, args.sampling_rate, audio.cpu().numpy())
LOGGER.iteration_stop()
LOGGER.finish()

View file

@@ -0,0 +1,5 @@
bash test_infer.sh -bs 1 -il 128 -p amp --num-iters 1003 --tacotron2 checkpoint_Tacotron2_amp --waveglow checkpoint_WaveGlow_amp
bash test_infer.sh -bs 4 -il 128 -p amp --num-iters 1003 --tacotron2 checkpoint_Tacotron2_amp --waveglow checkpoint_WaveGlow_amp
bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 checkpoint_Tacotron2_fp32 --waveglow checkpoint_WaveGlow_fp32
bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 checkpoint_Tacotron2_fp32 --waveglow checkpoint_WaveGlow_fp32

View file

@@ -491,9 +491,6 @@ class Decoder(nn.Module):
decoder_input = self.prenet(decoder_input, inference=True)
mel_output, gate_output, alignment = self.decode(decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
alignments += [alignment]
dec = torch.le(torch.sigmoid(gate_output.data),
self.gate_threshold).to(torch.int32).squeeze(1)
@@ -502,6 +499,11 @@
if self.early_stopping and torch.sum(not_finished) == 0:
break
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
alignments += [alignment]
if len(mel_outputs) == self.max_decoder_steps:
print("Warning! Reached max decoder steps")
break

View file

@@ -0,0 +1,316 @@
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
from tacotron2.text import text_to_sequence
import models
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import sys
import time
from dllogger.logger import LOGGER
import dllogger.logger as dllg
from dllogger.autologging import log_hardware, log_args
from apex import amp
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument('--tacotron2', type=str,
help='full path to the Tacotron2 model checkpoint file')
parser.add_argument('--waveglow', type=str,
help='full path to the WaveGlow model checkpoint file')
parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
help='Sampling rate')
parser.add_argument('--amp-run', action='store_true',
help='inference with AMP')
parser.add_argument('--log-file', type=str, default='nvlog.json',
help='Filename for logging')
parser.add_argument('--stft-hop-length', type=int, default=256,
help='STFT hop length for estimating audio length from mel size')
parser.add_argument('--num-iters', type=int, default=10,
help='Number of iterations')
parser.add_argument('-il', '--input-length', type=int, default=64,
help='Input length')
parser.add_argument('-bs', '--batch-size', type=int, default=1,
help='Batch size')
return parser
def checkpoint_from_distributed(state_dict):
"""
Checks whether a checkpoint was generated by DistributedDataParallel. DDP
wraps the model in an additional "module." prefix, which needs to be removed
for single-GPU inference.
:param state_dict: model's state dict
"""
ret = False
for key, _ in state_dict.items():
if key.find('module.') != -1:
ret = True
break
return ret
def unwrap_distributed(state_dict):
"""
Unwraps a model state dict saved by DistributedDataParallel by stripping
the "module." prefix that DDP adds, so the model can be loaded for
single-GPU inference.
:param state_dict: model's state dict
"""
new_state_dict = {}
for key, value in state_dict.items():
new_key = key.replace('module.', '')
new_state_dict[new_key] = value
return new_state_dict
def load_and_setup_model(model_name, parser, checkpoint, amp_run, to_cuda=True):
model_parser = models.parse_model_args(model_name, parser, add_help=False)
model_args, _ = model_parser.parse_known_args()
model_config = models.get_model_config(model_name, model_args)
model = models.get_model(model_name, model_config, to_cuda=to_cuda)
if checkpoint is not None:
if to_cuda:
state_dict = torch.load(checkpoint)['state_dict']
else:
state_dict = torch.load(checkpoint, map_location='cpu')['state_dict']
if checkpoint_from_distributed(state_dict):
state_dict = unwrap_distributed(state_dict)
model.load_state_dict(state_dict)
if model_name == "WaveGlow":
model = model.remove_weightnorm(model)
model.eval()
if amp_run:
model, _ = amp.initialize(model, [], opt_level="O3")
return model
# taken from tacotron2/data_function.py:TextMelCollate.__call__
def pad_sequences(batch):
# Right zero-pad all one-hot text sequences to max input length
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x) for x in batch]),
dim=0, descending=True)
max_input_len = input_lengths[0]
text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]]
text_padded[i, :text.size(0)] = text
return text_padded, input_lengths
def prepare_input_sequence(texts):
d = []
for text in texts:
d.append(torch.IntTensor(
text_to_sequence(text, ['english_cleaners'])))
text_padded, input_lengths = pad_sequences(d)
if torch.cuda.is_available():
text_padded = torch.autograd.Variable(text_padded).cuda().long()
input_lengths = torch.autograd.Variable(input_lengths).cuda().long()
else:
text_padded = torch.autograd.Variable(text_padded).long()
input_lengths = torch.autograd.Variable(input_lengths).long()
return text_padded, input_lengths
class MeasureTime():
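"""Context manager that stores the duration of its block in measurements[key], synchronizing CUDA before and after so asynchronous GPU work is counted."""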
def __init__(self, measurements, key):
self.measurements = measurements
self.key = key
def __enter__(self):
torch.cuda.synchronize()
self.t0 = time.perf_counter()
def __exit__(self, exc_type, exc_value, exc_traceback):
torch.cuda.synchronize()
self.measurements[self.key] = time.perf_counter() - self.t0
def main():
"""
Launches text to speech (inference).
Inference is executed on a single GPU.
"""
parser = argparse.ArgumentParser(
description='PyTorch Tacotron 2 Inference')
parser = parse_args(parser)
args, unknown_args = parser.parse_known_args()
LOGGER.set_model_name("Tacotron2_PyT")
LOGGER.set_backends([
dllg.JsonBackend(log_file=args.log_file,
logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1)
])
LOGGER.register_metric("pre_processing", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("tacotron2_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("tacotron2_latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("waveglow_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("waveglow_latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("type_conversion", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("storage", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("data_transfer", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("num_mels_per_audio", metric_scope=dllg.TRAIN_ITER_SCOPE)
LOGGER.register_metric("throughput", metric_scope=dllg.TRAIN_ITER_SCOPE)
measurements_all = {"pre_processing": [],
"tacotron2_latency": [],
"waveglow_latency": [],
"latency": [],
"type_conversion": [],
"data_transfer": [],
"storage": [],
"tacotron2_items_per_sec": [],
"waveglow_items_per_sec": [],
"num_mels_per_audio": [],
"throughput": []}
log_hardware()
log_args(args)
print("args:", args, unknown_args)
tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run)
waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run)
texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
texts = [texts[0][:args.input_length]]
texts = texts*args.batch_size
warmup_iters = 3
for iter in range(args.num_iters):
if iter >= warmup_iters:
LOGGER.iteration_start()
measurements = {}
with MeasureTime(measurements, "pre_processing"):
sequences_padded, input_lengths = prepare_input_sequence(texts)
with torch.no_grad():
with MeasureTime(measurements, "latency"):
with MeasureTime(measurements, "tacotron2_latency"):
_, mel, _, _, mel_lengths = tacotron2.infer(sequences_padded, input_lengths)
with MeasureTime(measurements, "waveglow_latency"):
audios = waveglow.infer(mel, sigma=args.sigma_infer)
num_mels = mel.size(0)*mel.size(2)
num_samples = audios.size(0)*audios.size(1)
with MeasureTime(measurements, "type_conversion"):
audios = audios.float()
with MeasureTime(measurements, "data_transfer"):
audios = audios.cpu()
with MeasureTime(measurements, "storage"):
audios = audios.numpy()
for i, audio in enumerate(audios):
audio_path = "audio_"+str(i)+".wav"
write(audio_path, args.sampling_rate,
audio[:mel_lengths[i]*args.stft_hop_length])
measurements['tacotron2_items_per_sec'] = num_mels/measurements['tacotron2_latency']
measurements['waveglow_items_per_sec'] = num_samples/measurements['waveglow_latency']
measurements['num_mels_per_audio'] = mel.size(2)
measurements['throughput'] = num_samples/measurements['latency']
if iter >= warmup_iters:
for k,v in measurements.items():
measurements_all[k].append(v)
LOGGER.log(key=k, value=v)
LOGGER.iteration_stop()
LOGGER.finish()
print(np.mean(measurements_all['latency'][1:]),
np.mean(measurements_all['throughput'][1:]),
np.mean(measurements_all['pre_processing'][1:]),
np.mean(measurements_all['type_conversion'][1:])+
np.mean(measurements_all['storage'][1:])+
np.mean(measurements_all['data_transfer'][1:]),
np.mean(measurements_all['num_mels_per_audio'][1:]))
throughput = measurements_all['throughput']
preprocessing = measurements_all['pre_processing']
type_conversion = measurements_all['type_conversion']
storage = measurements_all['storage']
data_transfer = measurements_all['data_transfer']
postprocessing = [sum(p) for p in zip(type_conversion,storage,data_transfer)]
latency = measurements_all['latency']
num_mels_per_audio = measurements_all['num_mels_per_audio']
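# sort latencies ascending; "cl N" below is the worst latency among the fastest N% of runs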
latency.sort()
cf_50 = max(latency[:int(len(latency)*0.50)])
cf_90 = max(latency[:int(len(latency)*0.90)])
cf_95 = max(latency[:int(len(latency)*0.95)])
cf_99 = max(latency[:int(len(latency)*0.99)])
cf_100 = max(latency[:int(len(latency)*1.0)])
print("Throughput average (samples/sec) = {:.4f}".format(np.mean(throughput)))
print("Preprocessing average (seconds) = {:.4f}".format(np.mean(preprocessing)))
print("Postprocessing average (seconds) = {:.4f}".format(np.mean(postprocessing)))
print("Number of mels per audio average = {}".format(np.mean(num_mels_per_audio)))
print("Latency average (seconds) = {:.4f}".format(np.mean(latency)))
print("Latency std (seconds) = {:.4f}".format(np.std(latency)))
print("Latency cl 50 (seconds) = {:.4f}".format(cf_50))
print("Latency cl 90 (seconds) = {:.4f}".format(cf_90))
print("Latency cl 95 (seconds) = {:.4f}".format(cf_95))
print("Latency cl 99 (seconds) = {:.4f}".format(cf_99))
print("Latency cl 100 (seconds) = {:.4f}".format(cf_100))
if __name__ == '__main__':
main()

View file

@@ -0,0 +1,68 @@
#!/bin/bash
BATCH_SIZE=1
INPUT_LENGTH=128
PRECISION="fp32"
NUM_ITERS=1003 # extra 3 iterations for warmup
TACOTRON2_CKPT="checkpoint_Tacotron2_1500_fp32"
WAVEGLOW_CKPT="checkpoint_WaveGlow_1000_fp32"
while [ -n "$1" ]
do
case "$1" in
-bs|--batch-size)
BATCH_SIZE="$2"
shift
;;
-il|--input-length)
INPUT_LENGTH="$2"
shift
;;
-p|--prec)
PRECISION="$2"
shift
;;
--num-iters)
NUM_ITERS="$2"
shift
;;
--tacotron2)
TACOTRON2_CKPT="$2"
shift
;;
--waveglow)
WAVEGLOW_CKPT="$2"
shift
;;
*)
echo "Option $1 not recognized"
esac
shift
done
# note: $AMP_RUN is consumed below; if the precision flag is not mapped elsewhere,
# a mapping along these lines is needed (assumed, not shown in the hunk above)
if [ "$PRECISION" = "amp" ]
then
AMP_RUN="--amp-run"
fi
LOG_SUFFIX=bs${BATCH_SIZE}_il${INPUT_LENGTH}_${PRECISION}
NVLOG_FILE=nvlog_${LOG_SUFFIX}.json
TMP_LOGFILE=tmp_log_${LOG_SUFFIX}.log
LOGFILE=log_${LOG_SUFFIX}.log
set -x
python test_infer.py \
--tacotron2 $TACOTRON2_CKPT \
--waveglow $WAVEGLOW_CKPT \
--batch-size $BATCH_SIZE \
--input-length $INPUT_LENGTH $AMP_RUN $CPU_RUN \
--log-file $NVLOG_FILE \
--num-iters $NUM_ITERS \
|& tee $TMP_LOGFILE
set +x
PERF=$(cat $TMP_LOGFILE | grep -F 'Throughput average (samples/sec)' | awk -F'= ' '{print $2}')
NUM_MELS=$(cat $TMP_LOGFILE | grep -F 'Number of mels per audio average' | awk -F'= ' '{print $2}')
LATENCY=$(cat $TMP_LOGFILE | grep -F 'Latency average (seconds)' | awk -F'= ' '{print $2}')
LATENCYSTD=$(cat $TMP_LOGFILE | grep -F 'Latency std (seconds)' | awk -F'= ' '{print $2}')
LATENCY50=$(cat $TMP_LOGFILE | grep -F 'Latency cl 50 (seconds)' | awk -F'= ' '{print $2}')
LATENCY100=$(cat $TMP_LOGFILE | grep -F 'Latency cl 100 (seconds)' | awk -F'= ' '{print $2}')
echo "$BATCH_SIZE,$INPUT_LENGTH,$PRECISION,$NUM_ITERS,$LATENCY,$LATENCYSTD,$LATENCY50,$LATENCY100,$PERF,$NUM_MELS" >> $LOGFILE

View file

@@ -0,0 +1,67 @@
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import sys
sys.path.append('tacotron2')
import torch
from common.layers import STFT
class Denoiser(torch.nn.Module):
""" Removes model bias from audio produced with waveglow """
def __init__(self, waveglow, filter_length=1024, n_overlap=4,
win_length=1024, mode='zeros'):
super(Denoiser, self).__init__()
self.stft = STFT(filter_length=filter_length,
hop_length=int(filter_length/n_overlap),
win_length=win_length).cuda()
if mode == 'zeros':
mel_input = torch.zeros(
(1, 80, 88),
dtype=waveglow.upsample.weight.dtype,
device=waveglow.upsample.weight.device)
elif mode == 'normal':
mel_input = torch.randn(
(1, 80, 88),
dtype=waveglow.upsample.weight.dtype,
device=waveglow.upsample.weight.device)
else:
raise Exception("Mode {} is not supported".format(mode))
with torch.no_grad():
bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
bias_spec, _ = self.stft.transform(bias_audio)
self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
def forward(self, audio, strength=0.1):
audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
audio_spec_denoised = audio_spec - self.bias_spec * strength
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
return audio_denoised
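A minimal usage sketch, mirroring how `inference.py` in this commit wires up the denoiser; `waveglow` and `mel` are assumed to be a loaded WaveGlow model and a Tacotron 2 mel-spectrogram:
```python
import torch

denoiser = Denoiser(waveglow).cuda()  # precomputes the bias spectrum once, from an all-zero mel
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.9)
# subtracts strength * bias_spec from the magnitude spectrogram, then inverts the STFT
audio = denoiser(audio.float(), strength=0.01).squeeze(1)
```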