[Tacotron2/PyT] Updates: better perf, better trt7 support, new logging, bug fixes

Przemek Strzelczyk 2020-02-28 15:36:14 +01:00
parent 155578a762
commit 77a1bb917a
25 changed files with 225 additions and 526 deletions

View file

@ -0,0 +1,4 @@
__pycache__/
/checkpoints/
/output/
nvlog.json

View file

@ -1,6 +1,6 @@
FROM nvcr.io/nvidia/pytorch:19.11-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.01-py3
FROM ${FROM_IMAGE_NAME}
ADD . /workspace/tacotron2
WORKDIR /workspace/tacotron2
RUN pip install -r requirements.txt
RUN pip --no-cache-dir --no-cache install 'git+https://github.com/NVIDIA/dllogger'
RUN pip install --no-cache-dir -r requirements.txt
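The base image is now parameterized with `ARG FROM_IMAGE_NAME`, so a different PyTorch container can be substituted at build time without editing the Dockerfile. A hedged sketch (the image tag to substitute is up to you):

```bash
# build on the default base image (nvcr.io/nvidia/pytorch:20.01-py3)
docker build -t tacotron2 .

# or override the base image via the build argument
docker build --build-arg FROM_IMAGE_NAME=<alternative-pytorch-image> -t tacotron2 .
```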

View file

@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/tensorrtserver:19.10-py3-clientsdk AS trt
FROM nvcr.io/nvidia/tensorrtserver:20.01-py3-clientsdk AS trt
FROM continuumio/miniconda3
RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract mc iputils-ping wget

View file

@ -231,7 +231,7 @@ and encapsulates some dependencies. Aside from these dependencies, ensure you
have the following components:
* [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
* [PyTorch 19.06-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
* [PyTorch 20.01-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
or newer
* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
@ -370,7 +370,7 @@ WaveGlow models.
* `--epochs` - number of epochs (Tacotron 2: 1501, WaveGlow: 1001)
* `--learning-rate` - learning rate (Tacotron 2: 1e-3, WaveGlow: 1e-4)
* `--batch-size` - batch size (Tacotron 2 FP16/FP32: 128/64, WaveGlow FP16/FP32: 10/4)
* `--batch-size` - batch size (Tacotron 2 FP16/FP32: 104/48, WaveGlow FP16/FP32: 10/4)
* `--amp-run` - use mixed precision training
#### Shared audio/STFT parameters
@ -496,21 +496,21 @@ To benchmark the training performance on a specific batch size, run:
* For 1 GPU
* FP16
```bash
python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path> --amp-run
python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path> --amp-run
```
* FP32
```bash
python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path>
python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path>
```
* For multiple GPUs
* FP16
```bash
python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path> --amp-run
python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path> --amp-run
```
* FP32
```bash
python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path>
python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path>
```
**WaveGlow**
@ -579,10 +579,10 @@ All of the results were produced using the `train.py` script as described in the
| WaveGlow FP16 | -2.2054 | -5.7602 | -5.901 | -5.9706 | -6.0258 |
| WaveGlow FP32 | -3.0327 | -5.858 | -6.0056 | -6.0613 | -6.1087 |
Tacotron 2 FP16 loss - batch size 128 (mean and std over 16 runs)
Tacotron 2 FP16 loss - batch size 104 (mean and std over 16 runs)
![](./img/tacotron2_amp_loss.png "Tacotron 2 FP16 loss")
Tacotron 2 FP32 loss - batch size 64 (mean and std over 16 runs)
Tacotron 2 FP32 loss - batch size 48 (mean and std over 16 runs)
![](./img/tacotron2_fp32_loss.png "Tacotron 2 FP32 loss")
WaveGlow FP16 loss - batch size 10 (mean and std over 16 runs)
@ -597,7 +597,7 @@ WaveGlow FP32 loss - batch size 4 (mean and std over 16 runs)
##### Training performance: NVIDIA DGX-1 (8x V100 16G)
Our results were obtained by running the `./platform/train_{tacotron2,waveglow}_{AMP,FP32}_DGX1_16GB_8GPU.sh`
training script in the PyTorch-19.06-py3 NGC container on NVIDIA DGX-1 with
training script in the PyTorch-19.12-py3 NGC container on NVIDIA DGX-1 with
8x V100 16G GPUs. Performance numbers (in output mel-spectrograms per second for
Tacotron 2 and output samples per second for WaveGlow) were averaged over
an entire training epoch.
@ -606,9 +606,9 @@ This table shows the results for Tacotron 2:
|Number of GPUs|Batch size per GPU|Output mels/sec with mixed precision|Output mels/sec with FP32|Speed-up with mixed precision|Multi-GPU weak scaling with mixed precision|Multi-GPU weak scaling with FP32|
|---:|---:|---:|---:|---:|---:|---:|
|1|128@FP16, 64@FP32 | 20,992 | 12,933 | 1.62 | 1.00 | 1.00 |
|4|128@FP16, 64@FP32 | 74,989 | 46,115 | 1.63 | 3.57 | 3.57 |
|8|128@FP16, 64@FP32 | 140,060 | 88,719 | 1.58 | 6.67 | 6.86 |
|1|104@FP16, 48@FP32 | 15,313 | 9,674 | 1.58 | 1.00 | 1.00 |
|4|104@FP16, 48@FP32 | 53,661 | 32,778 | 1.64 | 3.50 | 3.39 |
|8|104@FP16, 48@FP32 | 100,422 | 59,549 | 1.69 | 6.56 | 6.16 |
The following table shows the results for WaveGlow:
@ -626,9 +626,9 @@ The following table shows the expected training time for convergence for Tacotro
|Number of GPUs|Batch size per GPU|Time to train with mixed precision (Hrs)|Time to train with FP32 (Hrs)|Speed-up with mixed precision|
|---:|---:|---:|---:|---:|
|1| 128@FP16, 64@FP32 | 153 | 234 | 1.53 |
|4| 128@FP16, 64@FP32 | 42 | 64 | 1.54 |
|8| 128@FP16, 64@FP32 | 22 | 33 | 1.52 |
|1| 104@FP16, 48@FP32 | 193 | 312 | 1.62 |
|4| 104@FP16, 48@FP32 | 53 | 85 | 1.58 |
|8| 104@FP16, 48@FP32 | 31 | 45 | 1.47 |
The following table shows the expected training time for convergence for WaveGlow (1001 epochs):
@ -704,8 +704,11 @@ November 2019
* Implemented training resume from checkpoint
* Added notebook for running Tacotron 2 and WaveGlow in TRTIS.
December 2019
* Added `trt` subfolder for running Tacotron 2 and WaveGlow in TensorRT.
December 2019
* Added export and inference scripts for TensorRT. See [Tacotron2 TensorRT README](trt/README.md).
January 2020
* Updated batch sizes and performance results for Tacotron 2.
### Known issues

View file

@ -58,7 +58,7 @@ class STFT(torch.nn.Module):
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
np.linalg.pinv(scale * fourier_basis).T[:, None, :].astype(np.float32))
if window is not None:
assert(filter_length >= win_length)
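`np.linalg.pinv` returns a float64 array here, and the added `.astype(np.float32)` makes the downcast explicit before the array is wrapped in a float32 tensor. A toy sketch of the dtypes involved (random matrix, not the real Fourier basis):

```python
import numpy as np
import torch

scale = 1.0
fourier_basis = np.random.randn(10, 8)        # stand-in; the real basis is float64 as well

inv = np.linalg.pinv(scale * fourier_basis)   # dtype: float64
inverse_basis = torch.FloatTensor(inv.T[:, None, :].astype(np.float32))
print(inv.dtype, inverse_basis.dtype)         # float64 torch.float32
```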

View file

@ -27,6 +27,8 @@
import torch
import argparse
import sys
sys.path.append('./')
from inference import checkpoint_from_distributed, unwrap_distributed, load_and_setup_model
def parse_args(parser):

View file

@ -25,6 +25,7 @@
#
# *****************************************************************************
import types
import torch
import argparse
@ -113,32 +114,52 @@ def convert_1d_to_2d_(glow):
glow.cuda()
def test_inference(waveglow):
def infer_onnx(self, spect, z, sigma=0.9):
from scipy.io.wavfile import write
spect = self.upsample(spect)
# trim conv artifacts. maybe pad spec to kernel multiple
time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
spect = spect[:, :, :-time_cutoff]
mel = torch.load("mel.pt").cuda()
# mel = torch.load("mel_spectrograms/LJ001-0015.wav.pt").cuda()
# mel = mel.unsqueeze(0)
mel_lengths = [mel.size(2)]
stride = 256
kernel_size = 1024
n_group = 8
z_size2 = (mel.size(2)-1)*stride+(kernel_size-1)+1
# corresponds to cutoff in infer_onnx
z_size2 = z_size2 - (kernel_size-stride)
z_size2 = z_size2//n_group
z = torch.randn(1, n_group, z_size2, 1).cuda()
mel = mel.unsqueeze(3)
length_spect_group = spect.size(2)//8
mel_dim = 80
batch_size = spect.size(0)
with torch.no_grad():
audios = waveglow(mel, z)
spect = torch.squeeze(spect, 3)
spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
spect = spect.permute(0, 2, 1, 3)
spect = spect.contiguous()
spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
spect = spect.permute(0, 2, 1)
spect = torch.unsqueeze(spect, 3)
spect = spect.contiguous()
for i, audio in enumerate(audios):
audio = audio[:mel_lengths[i]*256]
audio = audio/torch.max(torch.abs(audio))
write("audio_pyt.wav", 22050, audio.cpu().numpy())
audio = z[:, :self.n_remaining_channels, :, :]
z = z[:, self.n_remaining_channels:self.n_group, :, :]
audio = sigma*audio
for k in reversed(range(self.n_flows)):
n_half = int(audio.size(1) / 2)
audio_0 = audio[:, :n_half, :, :]
audio_1 = audio[:, n_half:(n_half+n_half), :, :]
output = self.WN[k]((audio_0, spect))
s = output[:, n_half:(n_half+n_half), :, :]
b = output[:, :n_half, :, :]
audio_1 = (audio_1 - b) / torch.exp(s)
audio = torch.cat([audio_0, audio_1], 1)
audio = self.convinv[k](audio)
if k % self.n_early_every == 0 and k > 0:
audio = torch.cat((z[:, :self.n_early_size, :, :], audio), 1)
z = z[:, self.n_early_size:self.n_group, :, :]
audio = torch.squeeze(audio, 3)
audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
return audio
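The `z` latent length computed above is the `ConvTranspose1d` output length for the upsampled mel, minus the trimmed convolution artifacts, folded into `n_group` channels. A small worked check (620 is the fixed mel length the old TRTIS config assumed):

```python
stride, kernel_size, n_group = 256, 1024, 8
mel_len = 620                                                # fixed mel length in the previous config

upsampled = (mel_len - 1) * stride + (kernel_size - 1) + 1   # ConvTranspose1d output length
upsampled -= kernel_size - stride                            # trim conv artifacts, as in infer_onnx
z_len = upsampled // n_group

print(upsampled, z_len)                                      # 158720 19840 -- the old hard-coded dims
```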
def export_onnx(parser, args):
@ -166,12 +187,16 @@ def export_onnx(parser, args):
# export to ONNX
convert_1d_to_2d_(waveglow)
waveglow.forward = waveglow.infer_onnx
fType = types.MethodType
waveglow.forward = fType(infer_onnx, waveglow)
if args.amp_run:
waveglow.half()
mel = mel.unsqueeze(3)
opset_version = 10
torch.onnx.export(waveglow, (mel, z), args.output+"/"+"waveglow.onnx",
opset_version=opset_version,
do_constant_folding=True,
@ -181,8 +206,6 @@ def export_onnx(parser, args):
"z": {0: "batch_size", 2: "z_seq"},
"audio": {0: "batch_size", 1: "audio_seq"}})
test_inference(waveglow)
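Instead of pointing `forward` at the model's own `infer_onnx` method, the export script now binds the module-level `infer_onnx` defined above to the instance with `types.MethodType`, so `torch.onnx.export` traces that function as `forward`. A minimal sketch of the binding pattern (toy class, not the real model):

```python
import types

class Model:
    scale = 2

def infer(self, x):
    # free function; `self` is supplied once it is bound to an instance
    return self.scale * x

m = Model()
m.forward = types.MethodType(infer, m)   # same idea as waveglow.forward = fType(infer_onnx, waveglow)
print(m.forward(3))                      # 6
```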
def main():

View file

@ -64,20 +64,24 @@ def main():
config_template = r"""
name: "{model_name}"
platform: "tensorrt_plan"
default_model_filename: "waveglow_fp16.engine"
max_batch_size: 1
input {{
name: "0"
name: "mel"
data_type: {fp_type}
dims: [1, 80, 620, 1]
dims: [80, -1, 1]
}}
input {{
name: "1"
name: "z"
data_type: {fp_type}
dims: [1, 8, 19840, 1]
dims: [8, -1, 1]
}}
output {{
name: "1991"
name: "audio"
data_type: {fp_type}
dims: [1, 158720]
dims: [-1]
}}
"""

View file

@ -50,6 +50,7 @@ def parse_args(parser):
help='full path to the input text (phrases separated by newlines)')
parser.add_argument('-o', '--output', required=True,
help='output folder to save audio (file per phrase)')
parser.add_argument('--suffix', type=str, default="", help="output filename suffix")
parser.add_argument('--tacotron2', type=str,
help='full path to the Tacotron2 model checkpoint file')
parser.add_argument('--waveglow', type=str,
@ -242,7 +243,7 @@ def main():
for i, audio in enumerate(audios):
audio = audio[:mel_lengths[i]*args.stft_hop_length]
audio = audio/torch.max(torch.abs(audio))
audio_path = args.output + "audio_"+str(i)+".wav"
audio_path = args.output+"audio_"+str(i)+"_"+args.suffix+".wav"
write(audio_path, args.sampling_rate, audio.cpu().numpy())
DLLogger.flush()
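With the new flag, each output filename carries the suffix; note that the path is built by plain string concatenation, so `--output` should end with a slash. A hedged example invocation of the inference script (the script name, the `-i` input flag, and the checkpoint paths are assumed here, following the argparse definitions above):

```bash
python inference.py -i phrases/phrase_1_128.txt -o output/ --suffix fp16 \
    --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint>
# writes output/audio_0_fp16.wav, output/audio_1_fp16.wav, ...
```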

View file

@ -106,7 +106,7 @@ cd /workspace/onnx-tensorrt/build && cmake .. -DCMAKE_CXX_FLAGS=-isystem\ /usr/l
In order to export the model into the ONNX intermediate representation, type:
```bash
python exports/export_waveglow_onnx.py --waveglow <waveglow_checkpoint> --wn-channels 256 --amp-run
python exports/export_waveglow_onnx.py --waveglow <waveglow_checkpoint> --wn-channels 256 --amp-run --output ./output
```
This will save the model as `waveglow.onnx` in the directory given by `--output`.
@ -114,15 +114,15 @@ This will save the model as `waveglow.onnx` (you can change its name with the fl
With the model exported to ONNX, type the following to obtain a TRT engine and save it as `trtis_repo/waveglow/1/model.plan`:
```bash
onnx2trt <exported_waveglow_onnx> -o trtis_repo/waveglow/1/model.plan -b 1 -w 8589934592
python trt/export_onnx2trt.py --waveglow <exported_waveglow_onnx> -o trtis_repo/waveglow/1/ --fp16
```
### Set up the TRTIS server.
Download the TRTIS container by typing:
```bash
docker pull nvcr.io/nvidia/tensorrtserver:19.10-py3
docker tag nvcr.io/nvidia/tensorrtserver:19.10-py3 tensorrtserver:19.10
docker pull nvcr.io/nvidia/tensorrtserver:20.01-py3
docker tag nvcr.io/nvidia/tensorrtserver:20.01-py3 tensorrtserver:20.01
```
### Set up the TRTIS notebook client.
@ -130,14 +130,14 @@ docker tag nvcr.io/nvidia/tensorrtserver:19.10-py3 tensorrtserver:19.10
Now go to the root directory of the Tacotron 2 repo, and type:
```bash
docker build -f Dockerfile_trtis_client --network=host -t speech_ai__tts_only:demo .
docker build -f Dockerfile_trtis_client --network=host -t speech_ai_tts_only:demo .
```
### Run the TRTIS server.
To run the server, type in the root directory of the Tacotron 2 repo:
```bash
NV_GPU=1 nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 -v $PWD/trtis_repo/:/models tensorrtserver:19.10 trtserver --model-store=/models --log-verbose 1
NV_GPU=1 nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 -v $PWD/trtis_repo/:/models tensorrtserver:20.01 trtserver --model-store=/models --log-verbose 1
```
The `NV_GPU` flag selects which GPU the server sees. To expose all available GPUs, run the above command without this flag.
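For example (a sketch; GPU indices depend on the machine):

```bash
# expose only GPU 0 to the server
NV_GPU=0 nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 \
    -v $PWD/trtis_repo/:/models tensorrtserver:20.01 trtserver --model-store=/models

# expose all available GPUs by omitting NV_GPU
nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 \
    -v $PWD/trtis_repo/:/models tensorrtserver:20.01 trtserver --model-store=/models
```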
@ -147,7 +147,7 @@ By default, the model repository will be in `trtis_repo/`.
Leave the server running. In another terminal, type:
```bash
docker run -it --rm --network=host --device /dev/snd:/dev/snd --device /dev/usb:/dev/usb speech_ai__tts_only:demo bash ./run_this.sh
docker run -it --rm --network=host --device /dev/snd:/dev/snd --device /dev/usb:/dev/usb speech_ai_tts_only:demo bash ./run_this.sh
```
Open the URL in a browser, open `notebook.ipynb`, click play, and enjoy.

View file

@ -2,168 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bfd62b9362ec4dbb825e786821358a5e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(layout=Layout(height='1in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**tacotron2 input**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fead313d50594d0688a399cff8b6eb86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Textarea(value='type here', layout=Layout(height='80px', width='550px'), placeholder='')"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**tacotron2 preprocessing**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e610d8bd77ad44a7839d6ede103a71ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**tacotron2 output / waveglow input**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "240214dbb35c4bfd97b84df7cb5cd8cf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='2.1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**waveglow output**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5bc0e4519ea74a89b34c55fa8a040b02",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='2in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**play**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30976ec770c34d34880af8d7f3d6ff08",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output(layout=Layout(height='1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bfd62b9362ec4dbb825e786821358a5e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(layout=Layout(height='1in'))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"import os\n",
"import time\n",
@ -192,7 +33,7 @@
" 'protocol': 0, # 0: http, 1: grpc \n",
" 'autoplay': True, # autoplay\n",
" 'character_limit_min': 4, # don't touch this\n",
" 'character_limit_max': 124 # don't touch this\n",
" 'character_limit_max': 340 # don't touch this\n",
"}\n",
"\n",
"\n",
@ -330,23 +171,25 @@
" ::returns:: waveform\n",
" '''\n",
" # padding/trimming mel to dimension 620\n",
" mel = force_to_shape(mel, 620)\n",
" mel = mel[:,:,None]\n",
" # prepare input/output\n",
" mel = mel[None,:,:]\n",
" input_dict = {}\n",
" input_dict['0'] = (mel,)\n",
" shape = (8,19840,1)\n",
" shape = (1,*shape)\n",
" input_dict['1'] = np.random.normal(0.0, 1.0, shape).astype(mel.dtype)\n",
" input_dict['1'] = (input_dict['1'],)\n",
" input_dict['mel'] = (mel,)\n",
" stride = 256\n",
" kernel_size = 1024\n",
" n_group = 8\n",
" z_size = (mel.shape[1]-1)*stride + (kernel_size-1) + 1 - (kernel_size-stride)\n",
" z_size = z_size//n_group\n",
" shape = (n_group,z_size,1)\n",
" input_dict['z'] = np.random.normal(0.0, 1.0, shape).astype(mel.dtype)\n",
" input_dict['z'] = (input_dict['z'],)\n",
" output_dict = {}\n",
" output_dict['1991'] = InferContext.ResultFormat.RAW\n",
" output_dict['audio'] = InferContext.ResultFormat.RAW\n",
" batch_size = 1\n",
" # call waveglow\n",
" result = infer_ctx_waveglow.run(input_dict, output_dict, batch_size)\n",
" # get the results\n",
" signal = result['1991'][0] # take only the first instance in the output batch\n",
" signal = signal[0] # remove this line, when waveglow supports dynamic batch sizes\n",
" signal = result['audio'][0] # take only the first instance in the output batch\n",
" # postprocessing of waveglow: trimming signal to its actual size\n",
" trimmed_length = mel_lengths[0] * args.stft_hop_length\n",
" signal = signal[:trimmed_length] # trim\n",
@ -432,7 +275,7 @@
")\n",
"\n",
"# default text\n",
"text_area.value = \"I think grown-ups just act like they know what they're doing. \""
"text_area.value = \"The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves.\""
]
},
{
@ -459,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.4"
}
},
"nbformat": 4,

View file

@ -1,2 +1,2 @@
mkdir -p output
python train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3
python train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3

View file

@ -1,2 +1,2 @@
mkdir -p output
python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3
python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3

View file

@ -1,2 +1,2 @@
mkdir -p output
python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3
python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3

View file

@ -1,2 +1,2 @@
mkdir -p output
python train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 64 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1
python train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 48 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1

View file

@ -1,2 +1,2 @@
mkdir -p output
python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 64 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1
python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 48 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1

View file

@ -1,2 +1,2 @@
mkdir -p output
python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 64 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1
python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 48 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1

View file

@ -4,3 +4,4 @@ inflect
librosa
scipy
Unidecode
git+git://github.com/NVIDIA/dllogger.git@26a0f8f1958de2c0c460925ff6102a4d2486d6cc#egg=dllogger

View file

@ -1,2 +1,2 @@
mkdir -p output
python -m multiproc train.py -m Tacotron2 -o ./output/ -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1 --amp-run
python -m multiproc train.py -m Tacotron2 -o ./output/ -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1 --amp-run

View file

@ -236,18 +236,38 @@ def validate(model, criterion, valset, epoch, batch_iter, batch_size,
collate_fn=collate_fn)
val_loss = 0.0
num_iters = 0
val_items_per_sec = 0.0
for i, batch in enumerate(val_loader):
x, y, len_x = batch_to_gpu(batch)
torch.cuda.synchronize()
iter_start_time = time.perf_counter()
x, y, num_items = batch_to_gpu(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
if distributed_run:
reduced_val_loss = reduce_tensor(loss.data, world_size).item()
reduced_num_items = reduce_tensor(num_items.data, 1).item()
else:
reduced_val_loss = loss.item()
reduced_num_items = num_items.item()
val_loss += reduced_val_loss
val_loss = val_loss / (i + 1)
DLLogger.log(step=(epoch, batch_iter, epoch), data={'val_iter_loss': val_loss})
torch.cuda.synchronize()
iter_stop_time = time.perf_counter()
iter_time = iter_stop_time - iter_start_time
items_per_sec = reduced_num_items/iter_time
DLLogger.log(step=(epoch, batch_iter, i), data={'val_items_per_sec': items_per_sec})
val_items_per_sec += items_per_sec
num_iters += 1
val_loss = val_loss/(i + 1)
DLLogger.log(step=(epoch,), data={'val_loss': val_loss})
DLLogger.log(step=(epoch,), data={'val_items_per_sec':
(val_items_per_sec/num_iters if num_iters > 0 else 0.0)})
return val_loss
def adjust_learning_rate(iteration, epoch, optimizer, learning_rate,
@ -307,8 +327,8 @@ def main():
if distributed_run:
init_distributed(args, world_size, local_rank, args.group_name)
run_start_time = time.time()
DLLogger.log(step=tuple(), data={'run_start': run_start_time})
torch.cuda.synchronize()
run_start_time = time.perf_counter()
model_config = models.get_model_config(model_name, args)
model = models.get_model(model_name, model_config,
@ -350,8 +370,14 @@ def main():
model_name, n_frames_per_step)
trainset = data_functions.get_data_loader(
model_name, args.dataset_path, args.training_files, args)
train_sampler = DistributedSampler(trainset) if distributed_run else None
train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
if distributed_run:
train_sampler = DistributedSampler(trainset)
shuffle = False
else:
train_sampler = None
shuffle = True
train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
sampler=train_sampler,
batch_size=args.batch_size, pin_memory=False,
drop_last=True, collate_fn=collate_fn)
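A `DataLoader` may not be given both `shuffle=True` and a sampler, which is why shuffling moves into the `DistributedSampler` for multi-GPU runs, with `set_epoch` (called in the epoch loop below) reseeding it so every epoch is reshuffled. A minimal sketch of the pattern, assuming a toy dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())
distributed_run = False                           # would be True under torch.distributed

if distributed_run:
    train_sampler = DistributedSampler(dataset)   # shards and shuffles per rank
    shuffle = False                               # DataLoader forbids shuffle with a sampler
else:
    train_sampler = None
    shuffle = True

train_loader = DataLoader(dataset, batch_size=4, shuffle=shuffle,
                          sampler=train_sampler, drop_last=True)

for epoch in range(2):
    if distributed_run:
        train_loader.sampler.set_epoch(epoch)     # reseed so each epoch reshuffles
    for batch in train_loader:
        pass
```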
@ -362,21 +388,21 @@ def main():
batch_to_gpu = data_functions.get_batch_to_gpu(model_name)
iteration = 0
train_epoch_avg_items_per_sec = 0.0
train_epoch_items_per_sec = 0.0
val_loss = 0.0
num_iters = 0
model.train()
for epoch in range(start_epoch, args.epochs):
epoch_start_time = time.time()
DLLogger.log(step=(epoch,) , data={'train_epoch_start': epoch_start_time})
torch.cuda.synchronize()
epoch_start_time = time.perf_counter()
# used to calculate avg items/sec over epoch
reduced_num_items_epoch = 0
# used to calculate avg loss over epoch
train_epoch_avg_loss = 0.0
train_epoch_avg_items_per_sec = 0.0
train_epoch_items_per_sec = 0.0
num_iters = 0
@ -387,12 +413,11 @@ def main():
train_loader.sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
iter_start_time = time.time()
torch.cuda.synchronize()
iter_start_time = time.perf_counter()
DLLogger.log(step=(epoch, i),
data={'glob_iter/iters_per_epoch': str(iteration)+"/"+str(len(train_loader))})
DLLogger.log(step=(epoch, i), data={'train_iter_start': iter_start_time})
start = time.perf_counter()
adjust_learning_rate(iteration, epoch, optimizer, args.learning_rate,
args.anneal_steps, args.anneal_factor, local_rank)
@ -411,7 +436,7 @@ def main():
if np.isnan(reduced_loss):
raise Exception("loss is NaN")
DLLogger.log(step=(epoch,i), data={'train_iter_loss': reduced_loss})
DLLogger.log(step=(epoch,i), data={'train_loss': reduced_loss})
train_epoch_avg_loss += reduced_loss
num_iters += 1
@ -431,25 +456,24 @@ def main():
optimizer.step()
iter_stop_time = time.time()
torch.cuda.synchronize()
iter_stop_time = time.perf_counter()
iter_time = iter_stop_time - iter_start_time
items_per_sec = reduced_num_items/iter_time
train_epoch_avg_items_per_sec += items_per_sec
train_epoch_items_per_sec += items_per_sec
DLLogger.log(step=(epoch, i), data={'train_iter_items/sec': items_per_sec})
DLLogger.log(step=(epoch, i), data={'train_iter_stop': iter_stop_time})
DLLogger.log(step=(epoch, i), data={'train_items_per_sec': items_per_sec})
DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
iteration += 1
epoch_stop_time = time.time()
torch.cuda.synchronize()
epoch_stop_time = time.perf_counter()
epoch_time = epoch_stop_time - epoch_start_time
DLLogger.log(step=(epoch,), data={'train_epoch_items/sec': reduced_num_items_epoch/epoch_time})
DLLogger.log(step=(epoch,), data={'train_epoch_avg_items/sec':
(train_epoch_avg_items_per_sec/num_iters if num_iters > 0 else 0.0)})
DLLogger.log(step=(epoch,), data={'train_epoch_avg_loss': (train_epoch_avg_loss/num_iters if num_iters > 0 else 0.0)})
DLLogger.log(step=(epoch,), data={'epoch_time': epoch_time})
DLLogger.log(step=(epoch,), data={'train_items_per_sec':
(train_epoch_items_per_sec/num_iters if num_iters > 0 else 0.0)})
DLLogger.log(step=(epoch,), data={'train_loss': (train_epoch_avg_loss/num_iters if num_iters > 0 else 0.0)})
DLLogger.log(step=(epoch,), data={'train_epoch_time': epoch_time})
val_loss = validate(model, criterion, valset, epoch, i,
args.batch_size, world_size, collate_fn,
@ -463,14 +487,13 @@ def main():
if local_rank == 0:
DLLogger.flush()
run_stop_time = time.time()
DLLogger.log(step=tuple(), data={'run_stop': run_start_time})
torch.cuda.synchronize()
run_stop_time = time.perf_counter()
run_time = run_stop_time - run_start_time
DLLogger.log(step=tuple(), data={'run_time': run_time})
DLLogger.log(step=tuple(), data={'train_items_per_sec':
(train_epoch_avg_items_per_sec/num_iters if num_iters > 0 else 0.0)})
DLLogger.log(step=tuple(), data={'val_loss': val_loss})
DLLogger.log(step=tuple(), data={'train_items_per_sec':
(train_epoch_items_per_sec/num_iters if num_iters > 0 else 0.0)})
if local_rank == 0:
DLLogger.flush()
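The timing changes follow the usual recipe for benchmarking asynchronous CUDA work: kernels are queued without blocking, so the device must be synchronized before each clock reading, and `time.perf_counter()` provides the monotonic high-resolution clock. A minimal helper in that style:

```python
import time
import torch

def timed(fn, *args):
    """Return (result, seconds), draining queued CUDA work around the measurement."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start
```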

View file

@ -265,9 +265,13 @@ def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements):
z_size = (mel_size-1)*stride+(kernel_size-1)+1
z_size = z_size - (kernel_size-stride)
z_size = z_size//n_group
z = torch.randn(batch_size, n_group, z_size, 1).cuda().float()
z = torch.randn(batch_size, n_group, z_size, 1).cuda()
audios = torch.zeros(batch_size, mel_size*stride).cuda()
audios = torch.zeros(batch_size, mel_size*256).cuda()
if "HALF" in str(waveglow.get_binding_dtype(waveglow.get_binding_index("mel"))):
z = z.half()
mel = mel.half()
audios = audios.half()
waveglow_tensors = {
# inputs
@ -330,7 +334,6 @@ def main():
measurements = {}
sequences, sequence_lengths = prepare_input_sequence(texts)
print("|||sequence_lengths", sequence_lengths)
sequences = sequences.to(torch.int32)
sequence_lengths = sequence_lengths.to(torch.int32)
with MeasureTime(measurements, "latency"):
@ -342,7 +345,7 @@ def main():
with encoder_context, decoder_context, postnet_context, waveglow_context:
pass
audios.float()
audios = audios.float()
if args.waveglow_ckpt != "":
with MeasureTime(measurements, "denoiser"):
audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)

View file

@ -2,10 +2,5 @@
for i in {1..1003}
do
python trt/inference_trt.py -i ./phrases/phrase_1_128.txt --encoder ./output/encoder_fp32.engine --decoder ./output/decoder_iter_fp32.engine --postnet ./output/postnet_fp32.engine --waveglow ./output/waveglow_fp32.engine -o output_1/ >> tmp_log_bs1_fp32.log 2>&1
done
for i in {1..1003}
do
python trt/inference_trt.py -i ./phrases/phrase_1_128.txt --encoder ./output/encoder_fp16.engine --decoder ./output/decoder_iter_fp16.engine --postnet ./output/postnet_fp16.engine --waveglow ./output/waveglow_fp16.engine -o output_1/ >> tmp_log_bs1_fp16.log 2>&1
python trt/inference_trt.py -i ./phrases/phrase_1_128.txt --encoder ./output/encoder_fp16.engine --decoder ./output/decoder_iter_fp16.engine --postnet ./output/postnet_fp16.engine --waveglow ./output/waveglow_fp16.engine -o output/ --fp16 >> tmp_log_bs1_fp16.log 2>&1
done

View file

@ -1,181 +0,0 @@
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import tensorrt as trt
import sys
sys.path.append('./')
import time
import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
from apex import amp
from inference import MeasureTime, prepare_input_sequence
from test_infer import print_stats
from trt.inference_trt import infer_tacotron2_trt, infer_waveglow_trt
from trt.trt_utils import load_engine
import models
from test_infer import print_stats
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument('--encoder', type=str, required=True,
help='full path to the Encoder TRT engine')
parser.add_argument('--decoder', type=str, required=True,
help='full path to the DecoderIter TRT engine')
parser.add_argument('--postnet', type=str, required=True,
help='full path to the Postnet TRT engine')
parser.add_argument('--waveglow', type=str, required=True,
help='full path to the WaveGlow TRT engine')
parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
help='Sampling rate')
parser.add_argument('--fp16', action='store_true',
help='inference ')
parser.add_argument('--log-file', type=str, default='nvlog.json',
help='Filename for logging')
parser.add_argument('--stft-hop-length', type=int, default=256,
help='STFT hop length for estimating audio length from mel size')
parser.add_argument('--num-iters', type=int, default=10,
help='Number of iterations')
parser.add_argument('-il', '--input-length', type=int, default=64,
help='Input length')
parser.add_argument('-bs', '--batch-size', type=int, default=1,
help='Batch size')
return parser
def main():
"""
Launches text to speech (inference).
Inference is executed on a single GPU.
"""
parser = argparse.ArgumentParser(
description='PyTorch Tacotron 2 Inference')
parser = parse_args(parser)
args, unknown_args = parser.parse_known_args()
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
encoder = load_engine(args.encoder, TRT_LOGGER)
decoder_iter = load_engine(args.decoder, TRT_LOGGER)
postnet = load_engine(args.postnet, TRT_LOGGER)
waveglow = load_engine(args.waveglow, TRT_LOGGER)
# initialize CUDA state
torch.cuda.init()
# create TRT contexts for each engine
encoder_context = encoder.create_execution_context()
decoder_context = decoder_iter.create_execution_context()
postnet_context = postnet.create_execution_context()
waveglow_context = waveglow.create_execution_context()
DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
StdOutBackend(Verbosity.VERBOSE)])
for k,v in vars(args).items():
DLLogger.log(step="PARAMETER", data={k:v})
DLLogger.log(step="PARAMETER", data={'model_name':'Tacotron2_PyT'})
measurements_all = {"pre_processing": [],
"tacotron2_latency": [],
"waveglow_latency": [],
"latency": [],
"type_conversion": [],
"data_transfer": [],
"storage": [],
"tacotron2_items_per_sec": [],
"waveglow_items_per_sec": [],
"num_mels_per_audio": [],
"throughput": []}
texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
texts = [texts[0][:args.input_length]]
texts = texts*args.batch_size
warmup_iters = 3
for iter in range(args.num_iters):
measurements = {}
with MeasureTime(measurements, "pre_processing"):
sequences_padded, input_lengths = prepare_input_sequence(texts)
with torch.no_grad():
with MeasureTime(measurements, "latency"):
with MeasureTime(measurements, "tacotron2_latency"):
mel, mel_lengths = infer_tacotron2_trt(encoder, decoder_iter, postnet,
encoder_context, decoder_context, postnet_context,
sequences_padded, input_lengths, measurements)
with MeasureTime(measurements, "waveglow_latency"):
audios = infer_waveglow_trt(waveglow, waveglow_context, mel, measurements)
num_mels = mel.size(0)*mel.size(2)
num_samples = audios.size(0)*audios.size(1)
with MeasureTime(measurements, "type_conversion"):
audios = audios.float()
with MeasureTime(measurements, "data_transfer"):
audios = audios.cpu()
with MeasureTime(measurements, "storage"):
audios = audios.numpy()
for i, audio in enumerate(audios):
audio_path = "audio_"+str(i)+".wav"
write(audio_path, args.sampling_rate,
audio[:mel_lengths[i]*args.stft_hop_length])
measurements['tacotron2_items_per_sec'] = num_mels/measurements['tacotron2_latency']
measurements['waveglow_items_per_sec'] = num_samples/measurements['waveglow_latency']
measurements['num_mels_per_audio'] = mel.size(2)
measurements['throughput'] = num_samples/measurements['latency']
if iter >= warmup_iters:
for k,v in measurements.items():
if k in measurements_all.keys():
measurements_all[k].append(v)
DLLogger.log(step=(iter-warmup_iters), data={k: v})
with encoder_context, decoder_context, postnet_context, waveglow_context:
pass
DLLogger.flush()
print_stats(measurements_all)
if __name__ == '__main__':
main()

View file

@ -27,19 +27,6 @@
import tensorrt as trt
def binding_info(engine, context):
for i in range(engine.num_bindings):
print("|||| binding", i)
print("|||| binding_is_input", engine.binding_is_input(i))
print("|||| get_binding_dtype", engine.get_binding_dtype(i))
print("|||| get_binding_name", engine.get_binding_name(i))
print("|||| get_binding_shape", engine.get_binding_shape(i))
print("|||| get_binding_vectorized_dim", engine.get_binding_vectorized_dim(i))
print("|||| all_binding_shapes_specified", context.all_binding_shapes_specified)
print("|||| all_shape_inputs_specified", context.all_shape_inputs_specified)
def is_dimension_dynamic(dim):
return dim is None or dim <= 0
@ -60,7 +47,6 @@ def run_trt_engine(context, engine, tensors):
elif is_shape_dynamic(context.get_binding_shape(idx)):
context.set_binding_shape(idx, tensor.shape)
# binding_info(engine, context)
context.execute_v2(bindings=bindings)
@ -70,6 +56,45 @@ def load_engine(engine_filepath, trt_logger):
return engine
def engine_info(engine_filepath):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
engine = load_engine(engine_filepath, TRT_LOGGER)
binding_template = r"""
{btype} {{
name: "{bname}"
data_type: {dtype}
dims: {dims}
}}"""
type_mapping = {"DataType.HALF": "TYPE_FP16",
"DataType.FLOAT": "TYPE_FP32",
"DataType.INT32": "TYPE_INT32"}
print("engine name", engine.name)
print("has_implicit_batch_dimension", engine.has_implicit_batch_dimension)
start_dim = 0 if engine.has_implicit_batch_dimension else 1
print("num_optimization_profiles", engine.num_optimization_profiles)
print("max_batch_size:", engine.max_batch_size)
print("device_memory_size:", engine.device_memory_size)
print("max_workspace_size:", engine.max_workspace_size)
print("num_layers:", engine.num_layers)
for i in range(engine.num_bindings):
btype = "input" if engine.binding_is_input(i) else "output"
bname = engine.get_binding_name(i)
dtype = engine.get_binding_dtype(i)
bdims = engine.get_binding_shape(i)
config_values = {
"btype": btype,
"bname": bname,
"dtype": type_mapping[str(dtype)],
"dims": list(bdims[start_dim:])
}
final_binding_str = binding_template.format_map(config_values)
print(final_binding_str)
def build_engine(model_file, shapes, max_ws=512*1024*1024, fp16=False):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
@ -90,9 +115,7 @@ def build_engine(model_file, shapes, max_ws=512*1024*1024, fp16=False):
with open(model_file, 'rb') as model:
parsed = parser.parse(model.read())
for i in range(parser.num_errors):
e = parser.get_error(i)
print("TensorRT ONNX parser error:", parser.get_error(i))
engine = builder.build_engine(network, config=config)
return engine
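The new `engine_info` helper loads a serialized engine and prints each binding as a TRTIS-style `input {}`/`output {}` block, which is convenient when writing `config.pbtxt` by hand. A usage sketch (the engine path is a placeholder):

```python
# run from the repository root so the `trt` package is importable
from trt.trt_utils import engine_info

engine_info("trtis_repo/waveglow/1/model.plan")   # prints name, data_type and dims per binding
```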

View file

@ -78,7 +78,7 @@ class Invertible1x1Conv(torch.nn.Module):
return z
else:
# Forward computation
log_det_W = batch_size * n_of_groups * torch.logdet(W.float())
log_det_W = batch_size * n_of_groups * torch.logdet(W.unsqueeze(0).float()).squeeze()
z = self.conv(z)
return z, log_det_W
@ -273,51 +273,6 @@ class WaveGlow(torch.nn.Module):
return audio
def infer_onnx(self, spect, z, sigma=0.9):
spect = self.upsample(spect)
# trim conv artifacts. maybe pad spec to kernel multiple
time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
spect = spect[:, :, :-time_cutoff]
length_spect_group = spect.size(2)//8
mel_dim = 80
batch_size = spect.size(0)
spect = torch.squeeze(spect, 3)
spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
spect = spect.permute(0, 2, 1, 3)
spect = spect.contiguous()
spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
spect = spect.permute(0, 2, 1)
spect = torch.unsqueeze(spect, 3)
audio = z[:, :self.n_remaining_channels, :, :]
z = z[:, self.n_remaining_channels:self.n_group, :, :]
audio = sigma*audio
for k in reversed(range(self.n_flows)):
n_half = int(audio.size(1) / 2)
audio_0 = audio[:, :n_half, :, :]
audio_1 = audio[:, n_half:(n_half+n_half), :, :]
output = self.WN[k]((audio_0, spect))
s = output[:, n_half:(n_half+n_half), :, :]
b = output[:, :n_half, :, :]
audio_1 = (audio_1 - b) / torch.exp(s)
audio = torch.cat([audio_0, audio_1], 1)
audio = self.convinv[k](audio)
if k % self.n_early_every == 0 and k > 0:
audio = torch.cat((z[:, :self.n_early_size, :, :], audio), 1)
z = z[:, self.n_early_size:self.n_group, :, :]
audio = torch.squeeze(audio, 3)
audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
return audio
@staticmethod
def remove_weightnorm(model):
waveglow = model