added TRTIS demo to Tacotron2 (#281)

* added TRTIS demo

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
GrzegorzKarchNV 2019-11-06 19:46:35 +01:00 committed by nvpstr
parent 4760c03a04
commit 4e00153ab5
21 changed files with 1465 additions and 208 deletions

View file

@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:19.08-py3
FROM nvcr.io/nvidia/pytorch:19.10-py3
ADD . /workspace/tacotron2
WORKDIR /workspace/tacotron2

View file

@ -0,0 +1,41 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/tensorrtserver:19.10-py3-clientsdk AS trt
FROM continuumio/miniconda3
RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract mc iputils-ping wget
WORKDIR /workspace/speech_ai_demo__TTS/
# Copy the perf_client over
COPY --from=trt /workspace/install/ /workspace/install/
ENV LD_LIBRARY_PATH /workspace/install/lib:${LD_LIBRARY_PATH}
# set up env variables
ENV PATH="$PATH:/opt/conda/bin"
RUN cd /workspace/speech_ai_demo__TTS/
# jupyter lab extensions
RUN conda install -c conda-forge jupyterlab=1.0 ipywidgets=7.5 nodejs python-sounddevice librosa unidecode inflect
RUN jupyter labextension install @jupyter-widgets/jupyterlab-manager
RUN pip install /workspace/install/python/tensorrtserver*.whl
# Copy the python wheel and install with pip
COPY --from=trt /workspace/install/python/tensorrtserver*.whl /tmp/
RUN pip install /tmp/tensorrtserver*.whl && rm /tmp/tensorrtserver*.whl
RUN cd /workspace/speech_ai_demo__TTS/
COPY ./notebooks/trtis/ .
RUN mkdir /workspace/speech_ai_demo__TTS/tacotron2/
COPY ./tacotron2/text /workspace/speech_ai_demo__TTS/tacotron2/text
RUN chmod a+x /workspace/speech_ai_demo__TTS/run_this.sh

View file

@ -1,4 +1,4 @@
# Tacotron 2 And WaveGlow v1.7 For PyTorch
# Tacotron 2 And WaveGlow v1.10 For PyTorch
This repository provides a script and recipe to train Tacotron 2 and WaveGlow
v1.6 models to achieve state of the art accuracy, and is tested and maintained by NVIDIA.
@ -33,13 +33,13 @@ v1.6 models to achieve state of the art accuracy, and is tested and maintained b
* [Inference performance benchmark](#inference-performance-benchmark)
* [Results](#results)
* [Training accuracy results](#training-accuracy-results)
* [NVIDIA DGX-1 (8x V100 16G)](#nvidia-dgx-1-8x-v100-16g)
* [Training accuracy: NVIDIA DGX-1 (8x V100 16G)](#training-accuracy-nvidia-dgx-1-8x-v100-16g)
* [Training performance results](#training-performance-results)
* [NVIDIA DGX-1 (8x V100 16G)](#nvidia-dgx-1-8x-v100-16g)
* [Training performance: NVIDIA DGX-1 (8x V100 16G)](#training-performance-nvidia-dgx-1-8x-v100-16g)
* [Expected training time](#expected-training-time)
* [Inference performance results](#inference-performance-results)
* [NVIDIA V100 16G](#nvidia-v100-16g)
* [NVIDIA T4](#nvidia-t4)
* [Inference performance: NVIDIA V100 16G](#inference-performance-nvidia-v100-16g)
* [Inference performance: NVIDIA T4](#inference-performance-nvidia-t4)
* [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
@ -471,7 +471,7 @@ To run inference, issue:
```bash
python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> -o output/ --include-warmup -i phrases/phrase.txt --amp-run
```
Here, `Tacotron2_checkpoint` and `WaveGlow_checkpoint` are pre-trained
checkpoints for the respective models, and `phrases/phrase.txt` contains input
phrases. The number of text lines determines the inference batch size. Audio
will be saved in the output folder. The audio files [audio_fp16](./audio/audio_fp16.wav)
@ -564,7 +564,7 @@ and accuracy in training and inference.
#### Training accuracy results
##### NVIDIA DGX-1 (8x V100 16G)
##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
Our results were obtained by running the `./platform/train_{tacotron2,waveglow}_{AMP,FP32}_DGX1_16GB_8GPU.sh` training script in the PyTorch-19.06-py3
NGC container on NVIDIA DGX-1 with 8x V100 16G GPUs.
@ -594,7 +594,7 @@ WaveGlow FP32 loss - batch size 4 (mean and std over 16 runs)
#### Training performance results
##### NVIDIA DGX-1 (8x V100 16G)
##### Training performance: NVIDIA DGX-1 (8x V100 16G)
Our results were obtained by running the `./platform/train_{tacotron2,waveglow}_{AMP,FP32}_DGX1_16GB_8GPU.sh`
training script in the PyTorch-19.06-py3 NGC container on NVIDIA DGX-1 with
@ -648,26 +648,27 @@ deviation, and latency confidence intervals. Throughput is measured
as the number of generated audio samples per second. RTF is the real-time factor
which tells how many seconds of speech are generated in 1 second of compute.
##### NVIDIA V100 16G
##### Inference performance: NVIDIA DGX-1 (1x V100 16G)
|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 50% (s)|Latency confidence interval 100% (s)|Throughput (samples/sec)|Speed-up with mixed precision|Avg mels generated (81 mels=1 sec of speech)|Avg audio length (s)|Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|1| 128| FP16| 1.73| 0.07| 1.72| 2.11| 89,162| 1.09| 601| 6.98| 4.04|
|4| 128| FP16| 4.21| 0.17| 4.19| 4.84| 145,800| 1.16| 600| 6.97| 1.65|
|1| 128| FP32| 1.85| 0.06| 1.84| 2.19| 81,868| 1.00| 590| 6.85| 3.71|
|4| 128| FP32| 4.80| 0.15| 4.79| 5.43| 125,930| 1.00| 590| 6.85| 1.43|
|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 90% (s)|Latency confidence interval 95% (s)|Latency confidence interval 99% (s)|Throughput (samples/sec)|Speed-up with mixed precision|Avg mels generated (81 mels=1 sec of speech)|Avg audio length (s)|Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|1| 128| FP16| 1.27| 0.06| 1.34| 1.38| 1.41| 121,190| 1.37| 603| 7.00| 5.51|
|4| 128| FP16| 2.32| 0.09| 2.42| 2.45| 2.59| 277,711| 2.03| 628| 7.23| 3.12|
|1| 128| FP32| 1.70| 0.05| 1.77| 1.79| 1.84| 88,650| 1.00| 590| 6.85| 4.03|
|4| 128| FP32| 4.56| 0.12| 4.72| 4.77| 4.87| 136,518| 1.00| 608| 7.06| 1.55|
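For example, in the first FP16 row above, RTF is the average audio length divided by the average latency: 7.00 s / 1.27 s ≈ 5.51.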
##### NVIDIA T4
##### Inference performance: NVIDIA T4
|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 90% (s)|Latency confidence interval 95% (s)|Latency confidence interval 99% (s)|Throughput (samples/sec)|Speed-up with mixed precision|Avg mels generated (81 mels=1 sec of speech)|Avg audio length (s)|Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|1| 128| FP16| 3.13| 0.13| 3.28| 3.36| 3.46| 49,276| 1.26| 602| 6.99| 2.24|
|4| 128| FP16| 11.98| 0.42| 12.44| 12.70| 13.29| 53,676| 1.23| 628| 7.29| 0.61|
|1| 128| FP32| 3.88| 0.12| 4.04| 4.09| 4.19| 38,964| 1.00| 591| 6.86| 1.77|
|4| 128| FP32| 14.34| 0.42| 14.89| 15.08| 15.55| 43,489| 1.00| 609| 7.07| 0.49|
|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 50% (s)|Latency confidence interval 100% (s)|Throughput (samples/sec)|Speed-up with mixed precision|Avg mels generated (81 mels=1 sec of speech)|Avg audio length (s)|Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|1| 128| FP16| 3.16| 0.13| 3.16| 3.81| 48,792| 1.23| 603| 7.00| 2.21|
|4| 128| FP16| 11.45| 0.49| 11.39| 14.38| 53,771| 1.22| 601| 6.98| 0.61|
|1| 128| FP32| 3.82| 0.11| 3.81| 4.24| 39,603| 1.00| 591| 6.86| 1.80|
|4| 128| FP32| 13.80| 0.45| 13.74| 16.09| 43,915| 1.00| 592| 6.87| 0.50|
Our results were obtained by running the `./run_latency_tests.sh` script in
the PyTorch-19.06-py3 NGC container. Please note that to reproduce the results,
the PyTorch-19.09-py3 NGC container. Please note that to reproduce the results,
you need to provide pretrained checkpoints for Tacotron 2 and WaveGlow. Please
edit the script to provide your checkpoint filenames.
@ -696,7 +697,13 @@ August 2019
September 2019
* Introduced inference statistics
October 2019
* Tacotron 2 inference with torch.jit.script
November 2019
* Implemented training resume from checkpoint
* Added notebook for running Tacotron 2 and WaveGlow in TRTIS.
### Known issues
There are no known issues in this release.

View file

@ -33,8 +33,9 @@ import os
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
mask = (ids < lengths.unsqueeze(1)).byte()
mask = torch.le(mask, 0)
return mask
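For illustration only (not part of the commit), a minimal sketch of what the rewritten, device-aware mask computation returns on a toy batch; padded positions come out as `True`:
```python
import torch

lengths = torch.tensor([2, 3])                      # toy batch of two sequence lengths
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
mask = torch.le((ids < lengths.unsqueeze(1)).byte(), 0)
print(mask)
# [[False, False,  True],
#  [False, False, False]]   padded positions are True; printed dtype depends on the PyTorch version
```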

View file

@ -0,0 +1,67 @@
# *****************************************************************************
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
import argparse
from inference import checkpoint_from_distributed, unwrap_distributed, load_and_setup_model
from dllogger.autologging import log_hardware, log_args
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument('--tacotron2', type=str, required=True,
help='full path to the Tacotron2 model checkpoint file')
parser.add_argument('-o', '--output', type=str, default="trtis_repo/tacotron/1/model.pt",
help='filename for the Tacotron 2 TorchScript model')
parser.add_argument('--amp-run', action='store_true',
help='inference with AMP')
return parser
def main():
parser = argparse.ArgumentParser(
description='PyTorch Tacotron 2 Inference')
parser = parse_args(parser)
args = parser.parse_args()
log_args(args)
tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
args.amp_run, rename=True)
jitted_tacotron2 = torch.jit.script(tacotron2)
torch.jit.save(jitted_tacotron2, args.output)
if __name__ == '__main__':
main()
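For illustration only (not part of the commit), a minimal smoke test of the exported TorchScript module; the path matches the default used in the TRTIS README, the token IDs are random placeholders, and a CUDA device is assumed:
```python
import torch

# Load the TorchScript Tacotron 2 exported above (path is an assumption).
model = torch.jit.load("trtis_repo/tacotron2/1/model.pt")

# Random token IDs stand in for a preprocessed text sequence (batch of 1).
sequence = torch.randint(low=0, high=80, size=(1, 50), dtype=torch.int64).cuda()
input_lengths = torch.tensor([sequence.size(1)], dtype=torch.int64).cuda()

with torch.no_grad():
    mel, mel_lengths = model(sequence, input_lengths)

print(mel.shape)        # expected (1, 80, <generated mel frames>)
print(mel_lengths)
```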

View file

@ -0,0 +1,117 @@
# *****************************************************************************
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import os
import argparse
from dllogger.autologging import log_hardware, log_args
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument("--trtis_model_name",
type=str,
default='tacotron2',
help="exports to appropriate directory for TRTIS")
parser.add_argument("--trtis_model_version",
type=int,
default=1,
help="exports to appropriate directory for TRTIS")
parser.add_argument("--trtis_max_batch_size",
type=int,
default=8,
help="Specifies the 'max_batch_size' in the TRTIS model config.\
See the TRTIS documentation for more info.")
parser.add_argument('--amp-run', action='store_true',
help='inference with AMP')
return parser
def main():
parser = argparse.ArgumentParser(
description='PyTorch Tacotron 2 TRTIS config exporter')
parser = parse_args(parser)
args = parser.parse_args()
log_args(args)
# prepare repository
model_folder = os.path.join('./trtis_repo', args.trtis_model_name)
version_folder = os.path.join(model_folder, str(args.trtis_model_version))
if not os.path.exists(version_folder):
os.makedirs(version_folder)
# build the config for TRTIS
config_filename = os.path.join(model_folder, "config.pbtxt")
config_template = r"""
name: "{model_name}"
platform: "pytorch_libtorch"
max_batch_size: {max_batch_size}
input [
{{
name: "sequence__0"
data_type: TYPE_INT64
dims: [-1]
}},
{{
name: "input_lengths__1"
data_type: TYPE_INT64
dims: [1]
reshape: {{ shape: [ ] }}
}}
]
output [
{{
name: "mel_outputs_postnet__0"
data_type: {fp_type}
dims: [80,-1]
}},
{{
name: "mel_lengths__1"
data_type: TYPE_INT32
dims: [1]
reshape: {{ shape: [ ] }}
}}
]
"""
config_values = {
"model_name": args.trtis_model_name,
"max_batch_size": args.trtis_max_batch_size,
"fp_type": "TYPE_FP16" if args.amp_run else "TYPE_FP32"
}
with open(model_folder + "/config.pbtxt", "w") as file:
final_config_str = config_template.format_map(config_values)
file.write(final_config_str)
if __name__ == '__main__':
main()
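For orientation (an illustrative sketch, not part of the commit), the model repository layout this script prepares looks roughly as follows; `model.pt` is added later by `export_tacotron2_ts.py`:
```
trtis_repo/
└── tacotron2/           # --trtis_model_name (default)
    ├── config.pbtxt     # written by this script
    └── 1/               # --trtis_model_version (default)
        └── model.pt     # TorchScript model, exported separately
```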

View file

@ -0,0 +1,158 @@
# *****************************************************************************
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
import argparse
from inference import checkpoint_from_distributed, unwrap_distributed, load_and_setup_model
from dllogger.autologging import log_args
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument('--waveglow', type=str, required=True,
help='full path to the WaveGlow model checkpoint file')
parser.add_argument('-o', '--output', type=str, default="waveglow.onnx",
help='filename for the exported WaveGlow TRT engine')
parser.add_argument('--amp-run', action='store_true',
help='inference with AMP')
parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
return parser
def convert_convinv_1d_to_2d(convinv):
"""
Takes an invertible 1x1 1-d convolution and returns a 2-d convolution that does
the inverse
"""
conv2d = torch.nn.Conv2d(convinv.W_inverse.size(1),
convinv.W_inverse.size(0),
1, bias=False)
conv2d.weight.data[:,:,:,0] = convinv.W_inverse.data
return conv2d
def convert_conv_1d_to_2d(conv1d):
conv2d = torch.nn.Conv2d(conv1d.weight.size(1),
conv1d.weight.size(0),
(conv1d.weight.size(2), 1),
stride=(conv1d.stride[0], 1),
dilation=(conv1d.dilation[0], 1),
padding=(conv1d.padding[0], 0))
conv2d.weight.data[:,:,:,0] = conv1d.weight.data
conv2d.bias.data = conv1d.bias.data
return conv2d
def convert_WN_1d_to_2d_(WN):
"""
Modifies the WaveNet like affine coupling layer in-place to use 2-d convolutions
"""
WN.start = convert_conv_1d_to_2d(WN.start)
WN.end = convert_conv_1d_to_2d(WN.end)
for i in range(len(WN.in_layers)):
WN.in_layers[i] = convert_conv_1d_to_2d(WN.in_layers[i])
for i in range(len(WN.res_skip_layers)):
WN.res_skip_layers[i] = convert_conv_1d_to_2d(WN.res_skip_layers[i])
for i in range(len(WN.res_skip_layers)):
WN.cond_layers[i] = convert_conv_1d_to_2d(WN.cond_layers[i])
def convert_1d_to_2d_(glow):
"""
Caffe2 and TensorRT don't seem to support 1-d convolutions or properly
convert ONNX exports with 1d convolutions to 2d convolutions yet, so we
do the conversion to 2-d convolutions before ONNX export
"""
# Convert upsample to 2d
upsample = torch.nn.ConvTranspose2d(glow.upsample.weight.size(0),
glow.upsample.weight.size(1),
(glow.upsample.weight.size(2), 1),
stride=(glow.upsample.stride[0], 1))
upsample.weight.data[:,:,:,0] = glow.upsample.weight.data
upsample.bias.data = glow.upsample.bias.data
glow.upsample = upsample.cuda()
# Convert WN to 2d
for WN in glow.WN:
convert_WN_1d_to_2d_(WN)
# Convert invertible conv to 2d
for i in range(len(glow.convinv)):
glow.convinv[i] = convert_convinv_1d_to_2d(glow.convinv[i])
glow.cuda()
def export_onnx(parser, args):
waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run)
# 80 mel channels, 620 mel spectrograms ~ 7 seconds of speech
mel = torch.randn(1, 80, 620).cuda()
stride = 256 # value from waveglow upsample
kernel_size = 1024 # value from waveglow upsample
n_group = 8
z_size2 = (mel.size(2)-1)*stride+(kernel_size-1)+1
# corresponds to cutoff in infer_onnx
z_size2 = z_size2 - (kernel_size-stride)
z_size2 = z_size2//n_group
z = torch.randn(1, n_group, z_size2, 1).cuda()
if args.amp_run:
mel = mel.half()
z = z.half()
with torch.no_grad():
# run inference to force calculation of inverses
waveglow.infer(mel, sigma=args.sigma_infer)
# export to ONNX
convert_1d_to_2d_(waveglow)
waveglow.forward = waveglow.infer_onnx
if args.amp_run:
waveglow.half()
mel = mel.unsqueeze(3)
torch.onnx.export(waveglow, (mel, z), args.output)
def main():
parser = argparse.ArgumentParser(
description='PyTorch Tacotron 2 Inference')
parser = parse_args(parser)
args, _ = parser.parse_known_args()
log_args(args)
export_onnx(parser, args)
if __name__ == '__main__':
main()
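As a sanity check (an illustrative sketch, not part of the commit; it assumes the helpers above are importable as `export_waveglow_trt` with its dependencies on the path), the 1-d to 2-d conversion can be verified numerically on a random convolution:
```python
import torch
from export_waveglow_trt import convert_conv_1d_to_2d  # module name is an assumption

conv1d = torch.nn.Conv1d(4, 8, kernel_size=3, padding=1)
conv2d = convert_conv_1d_to_2d(conv1d)

x = torch.randn(2, 4, 100)                 # (batch, channels, time)
y1 = conv1d(x)
y2 = conv2d(x.unsqueeze(3)).squeeze(3)     # add and remove the dummy width dimension
print(torch.allclose(y1, y2, atol=1e-6))   # expect True
```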

View file

@ -0,0 +1,99 @@
# *****************************************************************************
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import os
import argparse
from dllogger.autologging import log_hardware, log_args
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument("--trtis_model_name",
type=str,
default='waveglow',
help="exports to appropriate directory for TRTIS")
parser.add_argument("--trtis_model_version",
type=int,
default=1,
help="exports to appropriate directory for TRTIS")
parser.add_argument('--amp-run', action='store_true',
help='inference with AMP')
return parser
def main():
parser = argparse.ArgumentParser(
description='PyTorch WaveGlow TRTIS config exporter')
parser = parse_args(parser)
args = parser.parse_args()
log_args(args)
# prepare repository
model_folder = os.path.join('./trtis_repo', args.trtis_model_name)
version_folder = os.path.join(model_folder, str(args.trtis_model_version))
if not os.path.exists(version_folder):
os.makedirs(version_folder)
# build the config for TRTIS
config_filename = os.path.join(model_folder, "config.pbtxt")
config_template = r"""
name: "{model_name}"
platform: "tensorrt_plan"
input {{
name: "0"
data_type: {fp_type}
dims: [1, 80, 620, 1]
}}
input {{
name: "1"
data_type: {fp_type}
dims: [1, 8, 19840, 1]
}}
output {{
name: "1991"
data_type: {fp_type}
dims: [1, 158720]
}}
"""
config_values = {
"model_name": args.trtis_model_name,
"fp_type": "TYPE_FP16" if args.amp_run else "TYPE_FP32"
}
with open(model_folder + "/config.pbtxt", "w") as file:
final_config_str = config_template.format_map(config_values)
file.write(final_config_str)
if __name__ == '__main__':
main()
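For reference (not part of the commit), the fixed dimensions in the template above follow from the sizes hard-coded in `export_waveglow_trt.py` (80 x 620 mel input, upsample stride 256, kernel size 1024, n_group 8):
```python
stride, kernel_size, n_group, mel_frames = 256, 1024, 8, 620

z_size = (mel_frames - 1) * stride + (kernel_size - 1) + 1  # transposed-conv output length
z_size -= kernel_size - stride                              # cutoff applied in infer_onnx
print(z_size // n_group)  # 19840  -> z input dims  [1, 8, 19840, 1]
print(z_size)             # 158720 -> audio output dims [1, 158720]
```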

View file

@ -111,12 +111,12 @@ def unwrap_distributed(state_dict):
return new_state_dict
def load_and_setup_model(model_name, parser, checkpoint, amp_run):
def load_and_setup_model(model_name, parser, checkpoint, amp_run, rename=False):
model_parser = models.parse_model_args(model_name, parser, add_help=False)
model_args, _ = model_parser.parse_known_args()
model_config = models.get_model_config(model_name, model_args)
model = models.get_model(model_name, model_config, to_cuda=True)
model = models.get_model(model_name, model_config, to_cuda=True, rename=rename)
if checkpoint is not None:
state_dict = torch.load(checkpoint)['state_dict']
@ -131,7 +131,7 @@ def load_and_setup_model(model_name, parser, checkpoint, amp_run):
model.eval()
if amp_run:
model, _ = amp.initialize(model, [], opt_level="O3")
model.half()
return model
@ -217,6 +217,10 @@ def main():
args.amp_run)
denoiser = Denoiser(waveglow).cuda()
tacotron2.forward = tacotron2.infer
type(tacotron2).forward = type(tacotron2).infer
jitted_tacotron2 = torch.jit.script(tacotron2)
texts = []
try:
f = open(args.input, 'r')
@ -231,7 +235,7 @@ def main():
input_lengths = torch.IntTensor([sequence.size(1)]).cuda().long()
for i in range(3):
with torch.no_grad():
_, mel, _, _, mel_lengths = tacotron2.infer(sequence, input_lengths)
mel, mel_lengths = jitted_tacotron2(sequence, input_lengths)
_ = waveglow.infer(mel)
LOGGER.iteration_start()
@ -241,7 +245,7 @@ def main():
sequences_padded, input_lengths = prepare_input_sequence(texts)
with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
_, mel, _, _, mel_lengths = tacotron2.infer(sequences_padded, input_lengths)
mel, mel_lengths = jitted_tacotron2(sequences_padded, input_lengths)
with torch.no_grad(), MeasureTime(measurements, "waveglow_time"):
audios = waveglow.infer(mel, sigma=args.sigma_infer)

View file

@ -101,7 +101,7 @@ def main():
dtype=torch.long).cuda()
input_lengths = torch.IntTensor([text_padded.size(1)]*args.batch_size).cuda().long()
with torch.no_grad(), MeasureTime(measurements, "inference_time"):
_, mels, _, _, _ = model.infer(text_padded, input_lengths)
mels, _ = model.infer(text_padded, input_lengths)
num_items = mels.size(0)*mels.size(2)
if args.model_name == 'WaveGlow':

View file

@ -63,11 +63,17 @@ def init_bn(module):
def get_model(model_name, model_config, to_cuda,
uniform_initialize_bn_weight=False):
uniform_initialize_bn_weight=False, rename=False):
""" Code chooses a model based on name"""
model = None
if model_name == 'Tacotron2':
model = Tacotron2(**model_config)
if rename:
class Tacotron2_extra(Tacotron2):
def forward(self, inputs, input_lengths):
return self.infer(inputs, input_lengths)
model = Tacotron2_extra(**model_config)
else:
model = Tacotron2(**model_config)
elif model_name == 'WaveGlow':
model = WaveGlow(**model_config)
else:

View file

@ -2,52 +2,49 @@
A Jupyter notebook based on the Quick Start Guide of https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2
## Requirements
Ensure you have the following components:
NVIDIA Docker (https://github.com/NVIDIA/nvidia-docker)
PyTorch 19.06-py3+ NGC container or newer (https://ngc.nvidia.com/catalog/containers/nvidia:pytorch)
NVIDIA Volta (https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or Turing (https://www.nvidia.com/en-us/geforce/turing/) based GPU
Before running the Jupyter notebook, please make sure you have already cloned the repository from GitHub:
```bash
git clone https://github.com/NVIDIA/DeepLearningExamples.git
cd DeepLearningExamples/PyTorch/SpeechSynthesis/Tacotron2
```
Copy the Tacotron2.ipynb file into the folder 'Tacotron2':
```bash
cp notebooks/Tacotron2.ipynb .
```
### Running the quick start guide as a Jupyter notebook
To run the notebook on your local machine:
```bash
jupyter notebook Tacotron2.ipynb
```
To run the notebook remotely:
```bash
jupyter notebook --ip=0.0.0.0 --allow-root
```
And navigate a web browser to the IP address or hostname of the host machine at port `8888`:
```
http://[host machine]:8888
```
Use the token listed in the output from running the `jupyter` command to log in, for example:
```
http://[host machine]:8888/?token=aae96ae9387cd28151868fee318c3b3581a2d794f3b25c6b
```

View file

@ -0,0 +1,25 @@
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,153 @@
# Tacotron 2 and WaveGlow inference on TRTIS
## Setup
### Clone the repository.
```bash
git clone https://github.com/NVIDIA/DeepLearningExamples.git
cd DeepLearningExamples/PyTorch/SpeechSynthesis/Tacotron2
```
### Obtain models to be loaded in TRTIS.
We have prepared Tacotron 2 and WaveGlow models that are ready to be loaded in TRTIS,
so you don't need to train and export the models yourself. The instructions below
explain how to train and export the models, or how to simply download the pretrained ones.
### Obtain Tacotron 2 and WaveGlow checkpoints.
You can either download the pretrained checkpoints or train the models yourself.
#### Download pretrained checkpoints.
If you want to use pretrained checkpoints, download them from [NGC](https://ngc.nvidia.com/catalog/models):
- [Tacotron2 checkpoint](https://ngc.nvidia.com/models/nvidia:tacotron2pyt_fp16)
- [WaveGlow checkpoint](https://ngc.nvidia.com/models/nvidia:waveglow256pyt_fp16)
#### Train Tacotron 2 and WaveGlow models.
In order to train the models, follow the QuickStart section in the `Tacotron2/README.md`
file by executing points 1-5. You have to train WaveGlow in a different way than described there. Use
the following command instead of the one given in QuickStart at point 5:
```bash
python -m multiproc train.py -m WaveGlow -o output/ --amp-run -lr 1e-4 --epochs 2001 --wn-channels 256 -bs 12 --segment-length 16000 --weight-decay 0 --grad-clip-thresh 65504.0 --cudnn-benchmark --cudnn-enabled --log-file output/nvlog.json
```
This will train the WaveGlow model with a smaller number of residual channels
in the coupling layer networks and a larger segment length. Training should take
about 100 hours on a DGX-1 (8x V100 16G).
### Setup Tacotron 2 TorchScript.
First, you need to create a folder structure for the model to be loaded in the TRTIS server.
Follow the Tacotron 2 Quick Start Guide (points 1-4) to start the container.
Inside the container, type:
```bash
cd /workspace/tacotron2/
python export_tacotron2_ts_config.py --amp-run
```
This will export the folder structure of the TRTIS repository and the config file of Tacotron 2.
By default, it will be found in the `trtis_repo/tacotron2` folder.
Now there are two ways to proceed.
#### Download the Tacotron 2 TorchScript model.
Download the Tacotron 2 TorchScript model from:
- [Tacotron2 TorchScript](https://ngc.nvidia.com/models/nvidia:tacotron2pyt_jit_fp16)
Move the downloaded model to `trtis_repo/tacotron2/1/model.pt`
#### Export the Tacotron 2 model using TorchScript.
To export the Tacotron 2 model using TorchScript, type:
```bash
python export_tacotron2_ts.py --tacotron2 <tacotron2_checkpoint> -o trtis_repo/tacotron2/1/model.pt --amp-run
```
This will save the model as `trtis_repo/tacotron2/1/model.pt`.
### Setup WaveGlow TRT engine.
For WaveGlow, we also need to create the folder structure that will be used by the TRTIS server.
Inside the container, type:
```bash
cd /workspace/tacotron2/
python export_waveglow_trt_config.py --amp-run
```
This will export the folder structure of the TRTIS repository and the config file of WaveGlow.
By default, it will be found in the `trtis_repo/waveglow` folder.
There are two ways to proceed.
#### Download the WaveGlow TRT engine.
Download the WaveGlow TRT engine from:
- [WaveGlow TRT engine](https://ngc.nvidia.com/models/nvidia:waveglow256pyt_trt_fp16)
Move the downloaded model to `trtis_repo/waveglow/1/model.plan`
#### Export the WaveGlow model to TRT.
Before exporting the model, you need to install onnx-tensorrt by typing:
```bash
cd /workspace && git clone https://github.com/onnx/onnx-tensorrt.git
cd /workspace/onnx-tensorrt/ && git submodule update --init --recursive
cd /workspace/onnx-tensorrt && mkdir -p build
cd /workspace/onnx-tensorrt/build && cmake .. -DCMAKE_CXX_FLAGS=-isystem\ /usr/local/cuda/include && make -j12 && make install
```
In order to export the model into the ONNX intermediate representation, type:
```bash
python export_waveglow_trt.py --waveglow <waveglow_checkpoint> --wn-channels 256 --amp-run
```
This will save the model as `waveglow.onnx` (you can change its name with the flag `--output <filename>`).
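Optionally (an illustrative sketch, not part of the commit; it assumes the `onnx` Python package is available in the container), you can sanity-check the exported graph before building the TRT engine:
```python
import onnx

model = onnx.load("waveglow.onnx")
onnx.checker.check_model(model)                     # raises if the graph is malformed
print([i.name for i in model.graph.input])          # the mel and z inputs
print([o.name for o in model.graph.output])         # the audio output
```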
With the model exported to ONNX, type the following to obtain a TRT engine and save it as `trtis_repo/waveglow/1/model.plan`:
```bash
onnx2trt <exported_waveglow_onnx> -o trtis_repo/waveglow/1/model.plan -b 1 -w 8589934592
```
### Setup the TRTIS server.
Download the TRTIS container by typing:
```bash
docker pull nvcr.io/nvidia/tensorrtserver:19.10-py3
docker tag nvcr.io/nvidia/tensorrtserver:19.10-py3 tensorrtserver:19.10
```
### Setup the TRTIS notebook client.
Now go to the root directory of the Tacotron 2 repo, and type:
```bash
docker build -f Dockerfile_trtis_client --network=host -t speech_ai__tts_only:demo .
```
### Run the TRTIS server.
To run the server, type in the root directory of the Tacotron 2 repo:
```bash
NV_GPU=1 nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 -v $PWD/trtis_repo/:/models tensorrtserver:19.10 trtserver --model-store=/models --log-verbose 1
```
The `NV_GPU` flag selects the GPU that the server is going to see. If you want it to see all available GPUs, run the above command without this flag.
By default, the model repository will be in `trtis_repo/`.
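Optionally (an illustrative sketch, not part of the commit; it assumes the `ServerHealthContext` API of the `tensorrtserver` client wheel used by this demo), you can check from the client container that the server is live and ready before starting the notebook:
```python
from tensorrtserver.api import ServerHealthContext

# 0 selects the HTTP protocol, as in the demo notebook; the server listens on port 8000.
health_ctx = ServerHealthContext("localhost:8000", 0)
print("live: ", health_ctx.is_live())
print("ready:", health_ctx.is_ready())
```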
### Run the TRTIS notebook client.
Leave the server running. In another terminal, type:
```bash
docker run -it --rm --network=host --device /dev/snd:/dev/snd --device /dev/usb:/dev/usb speech_ai__tts_only:demo bash ./run_this.sh
```
Open the URL in a browser, open `notebook.ipynb`, click play, and enjoy.

View file

@ -0,0 +1,467 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bfd62b9362ec4dbb825e786821358a5e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(layout=Layout(height='1in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**tacotron2 input**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fead313d50594d0688a399cff8b6eb86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Textarea(value='type here', layout=Layout(height='80px', width='550px'), placeholder='')"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**tacotron2 preprocessing**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e610d8bd77ad44a7839d6ede103a71ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**tacotron2 output / waveglow input**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "240214dbb35c4bfd97b84df7cb5cd8cf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='2.1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**waveglow output**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5bc0e4519ea74a89b34c55fa8a040b02",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='2in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**play**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30976ec770c34d34880af8d7f3d6ff08",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bfd62b9362ec4dbb825e786821358a5e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(layout=Layout(height='1in'))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import os\n",
"import time\n",
"import numpy as np\n",
"import collections\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm as cm\n",
"from IPython.display import Audio, display, clear_output, Markdown, Image\n",
"import librosa\n",
"import librosa.display\n",
"import ipywidgets as widgets\n",
"# \n",
"from tacotron2.text import text_to_sequence as text_to_sequence_internal\n",
"from tacotron2.text.symbols import symbols\n",
"# \n",
"from tensorrtserver.api import *\n",
"\n",
"\n",
"defaults = {\n",
" # settings\n",
" 'sigma_infer': 0.6, # don't touch this\n",
" 'sampling_rate': 22050, # don't touch this\n",
" 'stft_hop_length': 256, # don't touch this\n",
" 'url': 'localhost:8000', # don't touch this\n",
" 'protocol': 0, # 0: http, 1: grpc \n",
" 'autoplay': True, # autoplay\n",
" 'character_limit_min': 4, # don't touch this\n",
" 'character_limit_max': 124 # don't touch this\n",
"}\n",
"\n",
"\n",
"# create args object\n",
"class Struct:\n",
" def __init__(self, **entries):\n",
" self.__dict__.update(entries)\n",
"\n",
"\n",
"args = Struct(**defaults)\n",
"\n",
"\n",
"# create the inference context for the models\n",
"infer_ctx_tacotron2 = InferContext(args.url, args.protocol, 'tacotron2', -1)\n",
"infer_ctx_waveglow = InferContext(args.url, args.protocol, 'waveglow', -1)\n",
"\n",
"\n",
"def display_heatmap(sequence, title='preprocessed text'):\n",
" ''' displays sequence as a heatmap '''\n",
" clear_output(wait=True)\n",
" sequence = sequence[None, :]\n",
" plt.figure(figsize=(10, 2.5))\n",
" plt.title(title)\n",
" plt.tick_params(\n",
" axis='both',\n",
" which='both',\n",
" bottom=False,\n",
" top=False,\n",
" left=False,\n",
" right=False,\n",
" labelbottom=False,\n",
" labelleft=False)\n",
" plt.imshow(sequence, cmap='BrBG_r', interpolation='nearest')\n",
" plt.show()\n",
"\n",
"\n",
"def display_sound(signal, title, color):\n",
" ''' displays signal '''\n",
" clear_output(wait=True)\n",
" plt.figure(figsize=(10, 2.5))\n",
" plt.title(title)\n",
" plt.tick_params(\n",
" axis='both',\n",
" which='both',\n",
" bottom=True,\n",
" top=False,\n",
" left=False,\n",
" right=False,\n",
" labelbottom=True,\n",
" labelleft=False)\n",
" librosa.display.waveplot(signal, color=color)\n",
" plt.show()\n",
"\n",
"\n",
"def display_spectrogram(mel, title):\n",
" ''' displays mel spectrogram '''\n",
" clear_output(wait=True)\n",
" fig = plt.figure(figsize=(10, 2.5))\n",
" ax = fig.add_subplot(111)\n",
" plt.title(title)\n",
" plt.tick_params(\n",
" axis='both',\n",
" which='both',\n",
" bottom=True,\n",
" top=False,\n",
" left=False,\n",
" right=False,\n",
" labelbottom=True,\n",
" labelleft=False)\n",
" plt.xlabel('Time')\n",
" cmap = cm.get_cmap('jet', 30)\n",
" cax = ax.imshow(mel.astype(np.float32), interpolation=\"nearest\", cmap=cmap)\n",
" ax.grid(True)\n",
" plt.show()\n",
"\n",
"\n",
"def text_to_sequence(text):\n",
" ''' preprocessor of tacotron2\n",
" ::text:: the input str\n",
" ::returns:: sequence, the preprocessed text\n",
" '''\n",
" sequence = text_to_sequence_internal(text, ['english_cleaners'])\n",
" sequence = np.array(sequence, dtype=np.int64)\n",
" return sequence\n",
"\n",
"\n",
"def sequence_to_mel(sequence):\n",
" ''' calls tacotron2\n",
" ::sequence:: int64 numpy array, contains the preprocessed text\n",
" ::returns:: (mel, mel_lengths) pair\n",
" mel is the mel-spectrogram, np.array\n",
" mel_lengths contains the length of the unpadded mel, np.array\n",
" '''\n",
" input_lengths = [len(sequence)]\n",
" input_lengths = np.array(input_lengths, dtype=np.int64)\n",
" # prepare input/output\n",
" input_dict = {}\n",
" input_dict['sequence__0'] = (sequence,)\n",
" input_dict['input_lengths__1'] = (input_lengths,)\n",
" output_dict = {}\n",
" output_dict['mel_outputs_postnet__0'] = InferContext.ResultFormat.RAW\n",
" output_dict['mel_lengths__1'] = InferContext.ResultFormat.RAW\n",
" batch_size = 1\n",
" # call tacotron2\n",
" result = infer_ctx_tacotron2.run(input_dict, output_dict, batch_size)\n",
" # get results\n",
" mel = result['mel_outputs_postnet__0'][0] # take only the first instance in the output batch\n",
" mel_lengths = result['mel_lengths__1'][0] # take only the first instance in the output batch\n",
" return mel, mel_lengths\n",
"\n",
"\n",
"def force_to_shape(mel, length):\n",
" ''' preprocessor of waveglow\n",
" :: mel :: numpy array \n",
" :: length :: int \n",
" :: return :: m padded (or trimmed) to length in dimension 1\n",
" '''\n",
" diff = length - mel.shape[1]\n",
" if 0 < diff:\n",
" # pad it\n",
" min_value = mel.min()\n",
" shape = ((0,0),(0,diff))\n",
" ret = np.pad(mel, shape, mode='constant', constant_values=min_value)\n",
" else:\n",
" # trim it\n",
" ret = mel[:,:length]\n",
" ret = ret[:,:,None]\n",
" return ret\n",
"\n",
"\n",
"def mel_to_signal(mel, mel_lengths):\n",
" ''' calls waveglow\n",
" ::mel:: mel spectrogram\n",
" ::mel_lengths:: original length of mel spectrogram\n",
" ::returns:: waveform\n",
" '''\n",
" # padding/trimming mel to dimension 620\n",
" mel = force_to_shape(mel, 620)\n",
" # prepare input/output\n",
" mel = mel[None,:,:]\n",
" input_dict = {}\n",
" input_dict['0'] = (mel,)\n",
" shape = (8,19840,1)\n",
" shape = (1,*shape)\n",
" input_dict['1'] = np.random.normal(0.0, 1.0, shape).astype(mel.dtype)\n",
" input_dict['1'] = (input_dict['1'],)\n",
" output_dict = {}\n",
" output_dict['1991'] = InferContext.ResultFormat.RAW\n",
" batch_size = 1\n",
" # call waveglow\n",
" result = infer_ctx_waveglow.run(input_dict, output_dict, batch_size)\n",
" # get the results\n",
" signal = result['1991'][0] # take only the first instance in the output batch\n",
" signal = signal[0] # remove this line, when waveglow supports dynamic batch sizes\n",
" # postprocessing of waveglow: trimming signal to its actual size\n",
" trimmed_length = mel_lengths[0] * args.stft_hop_length\n",
" signal = signal[:trimmed_length] # trim\n",
" signal = signal.astype(np.float32)\n",
" return signal\n",
"\n",
"\n",
"# widgets\n",
"def get_output_widget(width, height):\n",
" ''' creates an output widget with default values and returns it '''\n",
" layout = widgets.Layout(width=width,\n",
" height=height,\n",
" object_fit='fill',\n",
" object_position = '{center} {center}')\n",
" ret = widgets.Output(layout=layout)\n",
" return ret\n",
"\n",
"\n",
"text_area = widgets.Textarea(\n",
" value='type here',\n",
" placeholder='',\n",
" description='',\n",
" disabled=False,\n",
" continuous_update=True,\n",
" layout=widgets.Layout(width='550px', height='80px')\n",
")\n",
"\n",
"\n",
"plot_text_area_preprocessed = get_output_widget(width='10in',height='1in')\n",
"plot_spectrogram = get_output_widget(width='10in',height='2.1in')\n",
"plot_signal = get_output_widget(width='10in',height='2.1in')\n",
"plot_play = get_output_widget(width='10in',height='1in')\n",
"\n",
"\n",
"def text_area_change(change):\n",
" ''' this gets called each time text_area.value changes '''\n",
" text = change['new']\n",
" text = text.strip(' ')\n",
" length = len(text)\n",
" if length < args.character_limit_min: # too short text\n",
" return\n",
" if length > args.character_limit_max: # too long text\n",
" text_area.value = text[:args.character_limit_max]\n",
" return\n",
" # preprocess tacotron2\n",
" sequence = text_to_sequence(text)\n",
" with plot_text_area_preprocessed:\n",
" display_heatmap(sequence)\n",
" # run tacotron2\n",
" mel, mel_lengths = sequence_to_mel(sequence)\n",
" with plot_spectrogram:\n",
" display_spectrogram(mel, change['new'])\n",
" # run waveglow\n",
" signal = mel_to_signal(mel, mel_lengths)\n",
" with plot_signal:\n",
" display_sound(signal, change['new'], 'green')\n",
" with plot_play:\n",
" clear_output(wait=True)\n",
" display(Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay))\n",
" # related issue: https://github.com/ipython/ipython/issues/11316\n",
"\n",
"\n",
"# setup callback\n",
"text_area.observe(text_area_change, names='value')\n",
"\n",
"# decorative widgets\n",
"empty = widgets.VBox([], layout=widgets.Layout(height='1in'))\n",
"markdown_4 = Markdown('**tacotron2 input**')\n",
"markdown_5 = Markdown('**tacotron2 preprocessing**')\n",
"markdown_6 = Markdown('**tacotron2 output / waveglow input**')\n",
"markdown_7 = Markdown('**waveglow output**')\n",
"markdown_8 = Markdown('**play**')\n",
"\n",
"# display widgets\n",
"display(\n",
" empty, \n",
" markdown_4, text_area, \n",
"# markdown_5, plot_text_area_preprocessed, \n",
" markdown_6, plot_spectrogram, \n",
" markdown_7, plot_signal, \n",
" markdown_8, plot_play, \n",
" empty\n",
")\n",
"\n",
"# default text\n",
"text_area.value = \"I think grown-ups just act like they know what they're doing. \""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View file

@ -0,0 +1 @@
jupyter lab --allow-root --ip=0.0.0.0 --no-browser notebook.ipynb

View file

@ -27,7 +27,6 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
import sys
@ -108,8 +107,7 @@ class Attention(nn.Module):
alignment = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
alignment.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
@ -126,17 +124,9 @@ class Prenet(nn.Module):
[LinearNorm(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_sizes, sizes)])
def forward(self, x, inference=False):
if inference:
for linear in self.layers:
x = F.relu(linear(x))
x0 = x[0].unsqueeze(0)
mask = Variable(torch.bernoulli(x0.data.new(x0.data.size()).fill_(0.5)))
mask = mask.expand(x.size(0), x.size(1))
x = x*mask*2
else:
for linear in self.layers:
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
def forward(self, x):
for linear in self.layers:
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
return x
@ -178,11 +168,16 @@ class Postnet(nn.Module):
dilation=1, w_init_gain='linear'),
nn.BatchNorm1d(n_mel_channels))
)
self.n_convs = len(self.convolutions)
def forward(self, x):
for i in range(len(self.convolutions) - 1):
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
i = 0
for conv in self.convolutions:
if i < self.n_convs - 1:
x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
else:
x = F.dropout(conv(x), 0.5, training=self.training)
i += 1
return x
@ -212,6 +207,7 @@ class Encoder(nn.Module):
int(encoder_embedding_dim / 2), 1,
batch_first=True, bidirectional=True)
@torch.jit.ignore
def forward(self, x, input_lengths):
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
@ -231,6 +227,25 @@ class Encoder(nn.Module):
return outputs
@torch.jit.export
def infer(self, x, input_lengths):
device = x.device
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x.to(device))), 0.5, self.training)
x = x.transpose(1, 2)
input_lengths = input_lengths.cpu()
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True)
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(
outputs, batch_first=True)
return outputs
class Decoder(nn.Module):
def __init__(self, n_mel_channels, n_frames_per_step,
@ -290,11 +305,14 @@ class Decoder(nn.Module):
decoder_input: all zeros frames
"""
B = memory.size(0)
decoder_input = Variable(memory.data.new(
B, self.n_mel_channels * self.n_frames_per_step).zero_())
dtype = memory.dtype
device = memory.device
decoder_input = torch.zeros(
B, self.n_mel_channels*self.n_frames_per_step,
dtype=dtype, device=device)
return decoder_input
def initialize_decoder_states(self, memory, mask):
def initialize_decoder_states(self, memory):
""" Initializes attention rnn states, decoder rnn states, attention
weights, attention cumulative weights, attention context, stores memory
and stores processed memory
@ -305,27 +323,31 @@ class Decoder(nn.Module):
"""
B = memory.size(0)
MAX_TIME = memory.size(1)
dtype = memory.dtype
device = memory.device
self.attention_hidden = Variable(memory.data.new(
B, self.attention_rnn_dim).zero_())
self.attention_cell = Variable(memory.data.new(
B, self.attention_rnn_dim).zero_())
attention_hidden = torch.zeros(
B, self.attention_rnn_dim, dtype=dtype, device=device)
attention_cell = torch.zeros(
B, self.attention_rnn_dim, dtype=dtype, device=device)
self.decoder_hidden = Variable(memory.data.new(
B, self.decoder_rnn_dim).zero_())
self.decoder_cell = Variable(memory.data.new(
B, self.decoder_rnn_dim).zero_())
decoder_hidden = torch.zeros(
B, self.decoder_rnn_dim, dtype=dtype, device=device)
decoder_cell = torch.zeros(
B, self.decoder_rnn_dim, dtype=dtype, device=device)
self.attention_weights = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(
B, self.encoder_embedding_dim).zero_())
attention_weights = torch.zeros(
B, MAX_TIME, dtype=dtype, device=device)
attention_weights_cum = torch.zeros(
B, MAX_TIME, dtype=dtype, device=device)
attention_context = torch.zeros(
B, self.encoder_embedding_dim, dtype=dtype, device=device)
self.memory = memory
self.processed_memory = self.attention_layer.memory_layer(memory)
self.mask = mask
processed_memory = self.attention_layer.memory_layer(memory)
return (attention_hidden, attention_cell, decoder_hidden,
decoder_cell, attention_weights, attention_weights_cum,
attention_context, processed_memory)
def parse_decoder_inputs(self, decoder_inputs):
""" Prepares decoder inputs, i.e. mel outputs
@ -362,21 +384,23 @@ class Decoder(nn.Module):
alignments:
"""
# (T_out, B) -> (B, T_out)
alignments = torch.stack(alignments).transpose(0, 1)
alignments = alignments.transpose(0, 1).contiguous()
# (T_out, B) -> (B, T_out)
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
gate_outputs = gate_outputs.contiguous()
gate_outputs = gate_outputs.transpose(0, 1).contiguous()
# (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
mel_outputs = mel_outputs.transpose(0, 1).contiguous()
# decouple frames per step
mel_outputs = mel_outputs.view(
mel_outputs.size(0), -1, self.n_mel_channels)
shape = (mel_outputs.shape[0], -1, self.n_mel_channels)
mel_outputs = mel_outputs.view(*shape)
# (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
mel_outputs = mel_outputs.transpose(1, 2)
return mel_outputs, gate_outputs, alignments
def decode(self, decoder_input):
def decode(self, decoder_input, attention_hidden, attention_cell,
decoder_hidden, decoder_cell, attention_weights,
attention_weights_cum, attention_context, memory,
processed_memory, mask):
""" Decoder step using stored states, attention and memory
PARAMS
------
@ -388,37 +412,41 @@ class Decoder(nn.Module):
gate_output: gate output energies
attention_weights:
"""
cell_input = torch.cat((decoder_input, self.attention_context), -1)
cell_input = torch.cat((decoder_input, attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(
self.attention_hidden, self.p_attention_dropout, self.training)
attention_hidden, attention_cell = self.attention_rnn(
cell_input, (attention_hidden, attention_cell))
attention_hidden = F.dropout(
attention_hidden, self.p_attention_dropout, self.training)
attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)), dim=1)
self.attention_context, self.attention_weights = self.attention_layer(
self.attention_hidden, self.memory, self.processed_memory,
attention_weights_cat, self.mask)
(attention_weights.unsqueeze(1),
attention_weights_cum.unsqueeze(1)), dim=1)
attention_context, attention_weights = self.attention_layer(
attention_hidden, memory, processed_memory,
attention_weights_cat, mask)
self.attention_weights_cum += self.attention_weights
attention_weights_cum += attention_weights
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
(attention_hidden, attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(
self.decoder_hidden, self.p_decoder_dropout, self.training)
decoder_hidden, decoder_cell = self.decoder_rnn(
decoder_input, (decoder_hidden, decoder_cell))
decoder_hidden = F.dropout(
decoder_hidden, self.p_decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1)
(decoder_hidden, attention_context), dim=1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
return (decoder_output, gate_prediction, attention_hidden,
attention_cell, decoder_hidden, decoder_cell, attention_weights,
attention_weights_cum, attention_context)
@torch.jit.ignore
def forward(self, memory, decoder_inputs, memory_lengths):
""" Decoder forward pass for training
PARAMS
@@ -439,25 +467,51 @@ class Decoder(nn.Module):
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
decoder_inputs = self.prenet(decoder_inputs)
self.initialize_decoder_states(
memory, mask=~get_mask_from_lengths(memory_lengths))
mask = get_mask_from_lengths(memory_lengths)
(attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
processed_memory) = self.initialize_decoder_states(memory)
mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
mel_output, gate_output, attention_weights = self.decode(
decoder_input)
(mel_output,
gate_output,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context) = self.decode(decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory,
mask)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
alignments += [attention_weights]
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
mel_outputs, gate_outputs, alignments)
torch.stack(mel_outputs),
torch.stack(gate_outputs),
torch.stack(alignments))
return mel_outputs, gate_outputs, alignments
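The decorator split works as follows: the teacher-forced forward above is excluded from TorchScript compilation with @torch.jit.ignore, while infer below is compiled and exposed on the scripted module via @torch.jit.export. A toy sketch (JitDemo is illustrative only):

```python
import torch
import torch.nn as nn


class JitDemo(nn.Module):
    @torch.jit.ignore
    def forward(self, x):
        return x * 2          # Python-only path, skipped by the compiler

    @torch.jit.export
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1          # compiled and callable on the scripted module


scripted = torch.jit.script(JitDemo())
print(scripted.infer(torch.ones(3)))  # tensor([2., 2., 2.])
```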
@torch.jit.export
def infer(self, memory, memory_lengths):
""" Decoder inference
PARAMS
@@ -472,26 +526,56 @@ class Decoder(nn.Module):
"""
decoder_input = self.get_go_frame(memory)
if memory.size(0) > 1:
mask = ~get_mask_from_lengths(memory_lengths)
else:
mask = None
mask = get_mask_from_lengths(memory_lengths)
(attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
processed_memory) = self.initialize_decoder_states(memory)
self.initialize_decoder_states(memory, mask=mask)
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).cuda()
not_finished = torch.ones([memory.size(0)], dtype=torch.int32).cuda()
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32)
not_finished = torch.ones([memory.size(0)], dtype=torch.int32)
if torch.cuda.is_available():
mel_lengths = mel_lengths.cuda()
not_finished = not_finished.cuda()
mel_outputs, gate_outputs, alignments = [], [], []
mel_outputs, gate_outputs, alignments = (
torch.zeros(1), torch.zeros(1), torch.zeros(1))
first_iter = True
while True:
decoder_input = self.prenet(decoder_input, inference=True)
mel_output, gate_output, alignment = self.decode(decoder_input)
decoder_input = self.prenet(decoder_input)
(mel_output,
gate_output,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context) = self.decode(decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory,
mask)
dec = torch.le(torch.sigmoid(gate_output.data),
if first_iter:
mel_outputs = mel_output.unsqueeze(0)
gate_outputs = gate_output
alignments = attention_weights
first_iter = False
else:
mel_outputs = torch.cat(
(mel_outputs, mel_output.unsqueeze(0)), dim=0)
gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
alignments = torch.cat((alignments, attention_weights), dim=0)
dec = torch.le(torch.sigmoid(gate_output),
self.gate_threshold).to(torch.int32).squeeze(1)
not_finished = not_finished*dec
@@ -499,11 +583,6 @@ class Decoder(nn.Module):
if self.early_stopping and torch.sum(not_finished) == 0:
break
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
alignments += [alignment]
if len(mel_outputs) == self.max_decoder_steps:
print("Warning! Reached max decoder steps")
break
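The inference loop above grows tensors with torch.cat instead of appending to Python lists, and uses the gate output to track which batch items have finished. A self-contained sketch of that bookkeeping with fabricated decoder outputs (the batch size, the random frames, and the mel_lengths update are illustrative assumptions):

```python
import torch

gate_threshold = 0.5
B = 3                                             # illustrative batch size
not_finished = torch.ones(B, dtype=torch.int32)
mel_lengths = torch.zeros(B, dtype=torch.int32)

mel_outputs = torch.zeros(1)                      # placeholder until the first step
first_iter = True
for step in range(4):
    mel_output = torch.randn(B, 80)               # stand-in for one decoded frame
    gate_output = torch.randn(B, 1)               # stand-in for the gate energies

    # grow a tensor instead of a Python list (TorchScript-friendly)
    if first_iter:
        mel_outputs = mel_output.unsqueeze(0)
        first_iter = False
    else:
        mel_outputs = torch.cat((mel_outputs, mel_output.unsqueeze(0)), dim=0)

    # an item stops once its gate sigmoid exceeds the threshold, and stays stopped
    dec = torch.le(torch.sigmoid(gate_output),
                   gate_threshold).to(torch.int32).squeeze(1)
    not_finished = not_finished * dec
    mel_lengths += not_finished                   # count frames per still-running item
    if torch.sum(not_finished) == 0:
        break

print(mel_outputs.shape, mel_lengths)
```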
@@ -530,8 +609,7 @@ class Tacotron2(nn.Module):
self.mask_padding = mask_padding
self.n_mel_channels = n_mel_channels
self.n_frames_per_step = n_frames_per_step
self.embedding = nn.Embedding(
n_symbols, symbols_embedding_dim)
self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
@@ -565,15 +643,16 @@ class Tacotron2(nn.Module):
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
(mel_padded, gate_padded))
def parse_output(self, outputs, output_lengths=None):
def parse_output(self, outputs, output_lengths):
# type: (List[Tensor], Tensor) -> List[Tensor]
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths)
mask = get_mask_from_lengths(output_lengths)
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[1].data.masked_fill_(mask, 0.0)
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
outputs[0].masked_fill_(mask, 0.0)
outputs[1].masked_fill_(mask, 0.0)
outputs[2].masked_fill_(mask[:, 0, :], 1e3) # gate energies
return outputs
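parse_output now takes the padding mask from get_mask_from_lengths without the ~ inversion and fills padded positions in place. A hedged sketch of the effect; padding_mask below is an illustrative stand-in for the repository's helper, and broadcasting replaces the expand/permute steps:

```python
import torch


def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in: True at padded time steps."""
    ids = torch.arange(int(lengths.max()), device=lengths.device)
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)


output_lengths = torch.tensor([3, 5])
mel = torch.randn(2, 80, 5)               # (B, n_mel_channels, T_out)
gate = torch.randn(2, 5)                  # (B, T_out) gate energies

mask = padding_mask(output_lengths)       # (B, T_out), True where padded
mel.masked_fill_(mask.unsqueeze(1), 0.0)  # zero out padded mel frames
gate.masked_fill_(mask, 1e3)              # force "stop" on padded steps
```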
@@ -595,17 +674,15 @@ class Tacotron2(nn.Module):
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
output_lengths)
def infer(self, inputs, input_lengths):
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, input_lengths)
encoder_outputs = self.encoder.infer(embedded_inputs, input_lengths)
mel_outputs, gate_outputs, alignments, mel_lengths = self.decoder.infer(
encoder_outputs, input_lengths)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
outputs = self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments, mel_lengths])
return outputs
return mel_outputs_postnet, mel_lengths


@@ -240,7 +240,7 @@ def main():
with torch.no_grad():
with MeasureTime(measurements, "latency"):
with MeasureTime(measurements, "tacotron2_latency"):
_, mel, _, _, mel_lengths = tacotron2.infer(sequences_padded, input_lengths)
mel, mel_lengths = tacotron2.infer(sequences_padded, input_lengths)
with MeasureTime(measurements, "waveglow_latency"):
audios = waveglow.infer(mel, sigma=args.sigma_infer)
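The benchmark wraps each stage in nested MeasureTime blocks so the Tacotron 2 and WaveGlow latencies are reported separately from the end-to-end figure. measure_time below is an illustrative stand-in, not the repository's MeasureTime implementation:

```python
import time
from contextlib import contextmanager


@contextmanager
def measure_time(measurements: dict, key: str):
    """Record the wall-clock duration of the block under `key`."""
    start = time.perf_counter()
    yield
    measurements[key] = time.perf_counter() - start


measurements = {}
with measure_time(measurements, "latency"):
    with measure_time(measurements, "tacotron2_latency"):
        time.sleep(0.01)                  # stands in for tacotron2.infer(...)
    with measure_time(measurements, "waveglow_latency"):
        time.sleep(0.01)                  # stands in for waveglow.infer(...)
print(measurements)
```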


@@ -6,7 +6,7 @@ PRECISION="fp32"
NUM_ITERS=1003 # extra 3 iterations for warmup
TACOTRON2_CKPT="checkpoint_Tacotron2_1500_fp32"
WAVEGLOW_CKPT="checkpoint_WaveGlow_1000_fp32"
AMP_RUN=""
while [ -n "$1" ]
do
@@ -41,6 +41,11 @@ do
shift
done
if [ "$PRECISION" = "amp" ]
then
AMP_RUN="--amp-run"
fi
LOG_SUFFIX=bs${BATCH_SIZE}_il${INPUT_LENGTH}_${PRECISION}
NVLOG_FILE=nvlog_${LOG_SUFFIX}.json
TMP_LOGFILE=tmp_log_${LOG_SUFFIX}.log
@@ -51,7 +56,7 @@ python test_infer.py \
--tacotron2 $TACOTRON2_CKPT \
--waveglow $WAVEGLOW_CKPT \
--batch-size $BATCH_SIZE \
--input-length $INPUT_LENGTH $AMP_RUN $CPU_RUN \
--input-length $INPUT_LENGTH $AMP_RUN \
--log-file $NVLOG_FILE \
--num-iters $NUM_ITERS \
|& tee $TMP_LOGFILE


@@ -69,12 +69,6 @@ def parse_args(parser):
help='Model to train')
parser.add_argument('--log-file', type=str, default='nvlog.json',
help='Filename for logging')
parser.add_argument('--phrase-path', type=str, default=None,
help='Path to phrase sequence file used for sample generation')
parser.add_argument('--waveglow-checkpoint', type=str, default=None,
help='Path to pre-trained WaveGlow checkpoint for sample generation')
parser.add_argument('--tacotron2-checkpoint', type=str, default=None,
help='Path to pre-trained Tacotron2 checkpoint for sample generation')
parser.add_argument('--anneal-steps', nargs='*',
help='Epochs after which decrease learning rate')
parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
@@ -86,6 +80,8 @@ def parse_args(parser):
help='Number of total epochs to run')
training.add_argument('--epochs-per-checkpoint', type=int, default=50,
help='Number of epochs per checkpoint')
training.add_argument('--checkpoint-path', type=str, default='',
help='Checkpoint path to resume training')
training.add_argument('--seed', type=int, default=1234,
help='Seed for PyTorch random number generators')
training.add_argument('--dynamic-loss-scaling', type=bool, default=True,
@@ -183,55 +179,38 @@ def init_distributed(args, world_size, rank, group_name):
print("Done initializing distributed")
def save_checkpoint(model, epoch, config, filepath):
def save_checkpoint(model, optimizer, epoch, config, amp_run, filepath):
print("Saving model and optimizer state at epoch {} to {}".format(
epoch, filepath))
torch.save({'epoch': epoch,
'config': config,
'state_dict': model.state_dict()}, filepath)
checkpoint = {'epoch': epoch,
'cuda_rng_state_all': torch.cuda.get_rng_state_all(),
'random_rng_state': torch.random.get_rng_state(),
'config': config,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()}
if amp_run:
checkpoint['amp'] = amp.state_dict()
torch.save(checkpoint, filepath)
def save_sample(model_name, model, waveglow_path, tacotron2_path, phrase_path, filepath, sampling_rate):
if phrase_path is None:
return
phrase = torch.load(phrase_path, map_location='cpu')
if model_name == 'Tacotron2':
if waveglow_path is None:
raise Exception(
"WaveGlow checkpoint path is missing, could not generate sample")
with torch.no_grad():
checkpoint = torch.load(waveglow_path, map_location='cpu')
waveglow = models.get_model(
'WaveGlow', checkpoint['config'], to_cuda=False)
waveglow.eval()
model.eval()
mel = model.infer(phrase.cuda())[0].cpu()
model.train()
audio = waveglow.infer(mel, sigma=0.6)
elif model_name == 'WaveGlow':
if tacotron2_path is None:
raise Exception(
"Tacotron2 checkpoint path is missing, could not generate sample")
with torch.no_grad():
checkpoint = torch.load(tacotron2_path, map_location='cpu')
tacotron2 = models.get_model(
'Tacotron2', checkpoint['config'], to_cuda=False)
tacotron2.eval()
mel = tacotron2.infer(phrase)[0].cuda()
model.eval()
audio = model.infer(mel, sigma=0.6).cpu()
model.train()
else:
raise NotImplementedError(
"unknown model requested: {}".format(model_name))
audio = audio[0].numpy()
audio = audio.astype('int16')
write_wav(filepath, sampling_rate, audio)
def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath):
checkpoint = torch.load(filepath, map_location='cpu')
epoch[0] = checkpoint['epoch']+1
torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state_all'])
torch.random.set_rng_state(checkpoint['random_rng_state'])
config = checkpoint['config']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
if amp_run:
amp.load_state_dict(checkpoint['amp'])
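Checkpoints now carry the optimizer state, the RNG states and, when --amp-run is set, the amp state, so a run restarted with --checkpoint-path picks up at the following epoch. A minimal round-trip sketch with a toy model (CPU only, no AMP, CUDA RNG state omitted; the file name is illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# save: model, optimizer and RNG state in one file
checkpoint = {'epoch': 10,
              'random_rng_state': torch.random.get_rng_state(),
              'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint_toy.pt')

# load: restore the states and resume from the next epoch
checkpoint = torch.load('checkpoint_toy.pt', map_location='cpu')
start_epoch = checkpoint['epoch'] + 1
torch.random.set_rng_state(checkpoint['random_rng_state'])
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
```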
# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
# Following snippet is licensed under MIT license
@contextmanager
def evaluating(model):
'''Temporarily switch to evaluation mode.'''
@@ -364,6 +343,14 @@ def main():
except AttributeError:
sigma = None
start_epoch = [0]
if args.checkpoint_path != "":
load_checkpoint(model, optimizer, start_epoch, model_config,
args.amp_run, args.checkpoint_path)
start_epoch = start_epoch[0]
criterion = loss_functions.get_loss_function(model_name, sigma)
try:
@@ -391,7 +378,7 @@ def main():
LOGGER.log(key=tags.TRAIN_LOOP)
for epoch in range(args.epochs):
for epoch in range(start_epoch, args.epochs):
LOGGER.epoch_start()
epoch_start_time = time.time()
LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)
@@ -407,6 +394,9 @@ def main():
# if overflow at the last iteration then do not save checkpoint
overflow = False
if distributed_run:
train_loader.sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
print("Batch: {}/{} epoch {}".format(i, len(train_loader), epoch))
LOGGER.iteration_start()
@@ -489,10 +479,8 @@ def main():
if (epoch % args.epochs_per_checkpoint == 0) and args.rank == 0:
checkpoint_path = os.path.join(
args.output_directory, "checkpoint_{}_{}".format(model_name, epoch))
save_checkpoint(model, epoch, model_config, checkpoint_path)
save_sample(model_name, model, args.waveglow_checkpoint,
args.tacotron2_checkpoint, args.phrase_path,
os.path.join(args.output_directory, "sample_{}_{}.wav".format(model_name, iteration)), args.sampling_rate)
save_checkpoint(model, optimizer, epoch, model_config,
args.amp_run, checkpoint_path)
LOGGER.epoch_stop()


@@ -271,6 +271,50 @@ class WaveGlow(torch.nn.Module):
audio.size(0), -1).data
return audio
def infer_onnx(self, spect, z, sigma=1.0):
spect = self.upsample(spect)
# trim conv artifacts. maybe pad spec to kernel multiple
time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
spect = spect[:, :, :-time_cutoff]
length_spect_group = int(spect.size(2)/8)
mel_dim = 80
spect = torch.squeeze(spect, 3)
spect = spect.view((1, mel_dim, length_spect_group, self.n_group))
spect = spect.permute(0, 2, 1, 3)
spect = spect.contiguous()
spect = spect.view((1, length_spect_group, self.n_group*mel_dim))
spect = spect.permute(0, 2, 1)
spect = torch.unsqueeze(spect, 3)
audio = z[:, :self.n_remaining_channels, :, :]
z = z[:, self.n_remaining_channels:self.n_group, :, :]
audio = sigma*audio
for k in reversed(range(self.n_flows)):
n_half = int(audio.size(1) / 2)
audio_0 = audio[:, :n_half, :, :]
audio_1 = audio[:, n_half:(n_half+n_half), :, :]
output = self.WN[k]((audio_0, spect))
s = output[:, n_half:(n_half+n_half), :, :]
b = output[:, :n_half, :, :]
audio_1 = (audio_1 - b) / torch.exp(s)
audio = torch.cat([audio_0, audio_1], 1)
audio = self.convinv[k](audio)
if k % self.n_early_every == 0 and k > 0:
audio = torch.cat((z[:, :self.n_early_size, :, :], audio), 1)
z = z[:, self.n_early_size:self.n_group, :, :]
audio = torch.squeeze(audio, 3)
audio = audio.permute(0,2,1).contiguous().view(1, (length_spect_group * self.n_group))
return audio
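infer_onnx folds n_group consecutive time steps of the upsampled spectrogram into the channel dimension, so every flow step is conditioned on a (1, n_group*80, T/n_group, 1) tensor. A small sketch of just that reshape, with illustrative sizes:

```python
import torch

n_group, mel_dim, L = 8, 80, 4                  # L = spectrogram frames / n_group
spect = torch.randn(1, mel_dim, L * n_group)    # upsampled mel, (1, 80, T)

spect = spect.view(1, mel_dim, L, n_group)      # split T into groups of n_group
spect = spect.permute(0, 2, 1, 3).contiguous()  # (1, L, 80, 8)
spect = spect.view(1, L, n_group * mel_dim)     # fold the group into channels
spect = spect.permute(0, 2, 1).unsqueeze(3)     # (1, 640, L, 1)
assert spect.shape == (1, n_group * mel_dim, L, 1)
```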
@staticmethod
def remove_weightnorm(model):
waveglow = model