4f2ea4913c
* squash Signed-off-by: Jason <jasoli@nvidia.com> * add comments Signed-off-by: Jason <jasoli@nvidia.com> * style and cleanup Signed-off-by: Jason <jasoli@nvidia.com> * cleanup Signed-off-by: Jason <jasoli@nvidia.com> * add new test file Signed-off-by: Jason <jasoli@nvidia.com> * syntax Signed-off-by: Jason <jasoli@nvidia.com> * style Signed-off-by: Jason <jasoli@nvidia.com> * typo Signed-off-by: Jason <jasoli@nvidia.com> * update Signed-off-by: Jason <jasoli@nvidia.com> * update Signed-off-by: Jason <jasoli@nvidia.com> * update Signed-off-by: Jason <jasoli@nvidia.com> * try again Signed-off-by: Jason <jasoli@nvidia.com> * wip Signed-off-by: Jason <jasoli@nvidia.com> * style; ci should fail Signed-off-by: Jason <jasoli@nvidia.com> * final Signed-off-by: Jason <jasoli@nvidia.com>
375 lines
12 KiB
Python
375 lines
12 KiB
Python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from enum import Enum
|
|
from typing import Dict, Sequence
|
|
|
|
import librosa
|
|
import matplotlib.pylab as plt
|
|
import numpy as np
|
|
import torch
|
|
from numba import jit, prange
|
|
from numpy import ndarray
|
|
from pesq import pesq
|
|
from pystoi import stoi
|
|
|
|
from nemo.utils import logging
|
|
|
|
try:
    from pytorch_lightning.utilities import rank_zero_only
except ModuleNotFoundError:
    from functools import wraps

    def rank_zero_only(fn):
        """Fallback decorator used when PyTorch Lightning cannot be imported.

        Args:
            fn: The function being decorated.

        Returns:
            A wrapper carrying ``fn``'s metadata. Calling the wrapper logs an error and
            terminates the process with exit status 1; ``fn`` itself is never executed.
        """

        @wraps(fn)
        def wrapped_fn(*args, **kwargs):
            logging.error(
                f"Function {fn} requires lighting to be installed, but it was not found. Please install lightning first"
            )
            # Raise SystemExit directly: the ``exit`` helper is injected by the ``site``
            # module and is not guaranteed to exist in every interpreter setup.
            raise SystemExit(1)

        # Without this return the decorator evaluates to None and every decorated
        # function in this module would be replaced by None.
        return wrapped_fn
|
|
|
|
|
|
class OperationMode(Enum):
    """Enumerates the phases a model can run in: training, validation, or inference."""

    # Integer values are part of the public contract; keep them stable.
    training = 0
    validation = 1
    infer = 2
|
|
|
|
|
|
def binarize_attention(attn, in_len, out_len):
    """Turn a soft attention matrix into a hard (0/1) alignment, one batch item at a time.

    Each item is cropped to its own mel/text lengths, passed through ``mas`` on the CPU,
    and the resulting path is written back into a zero tensor shaped like ``attn``.

    Args:
        attn (torch.Tensor): B x 1 x max_mel_len x max_text_len. Soft attention matrix.
        in_len (torch.Tensor): B. Lengths of texts.
        out_len (torch.Tensor): B. Lengths of spectrograms.

    Output:
        attn_out (torch.Tensor): B x 1 x max_mel_len x max_text_len. Hard attention matrix,
            final dim max_text_len should sum to 1.
    """
    with torch.no_grad():
        soft_np = attn.data.cpu().numpy()
        hard = torch.zeros_like(attn)
        for item in range(attn.shape[0]):
            n_mel = out_len[item]
            n_text = in_len[item]
            # Only the valid (unpadded) region takes part in the alignment search.
            path = mas(soft_np[item, 0, :n_mel, :n_text])
            hard[item, 0, :n_mel, :n_text] = torch.tensor(path, device=attn.device)
    return hard
|
|
|
|
|
|
def binarize_attention_parallel(attn, in_lens, out_lens):
    """For training purposes only. Binarizes attention with MAS.

    The batch is moved to the CPU and processed by ``b_mas``. The result is built
    under ``torch.no_grad()`` from a NumPy array, so it will no longer receive a gradient.

    Args:
        attn: B x 1 x max_mel_len x max_text_len
        in_lens: tensor of per-item lengths used by ``b_mas`` to crop the last (text) dimension.
        out_lens: tensor of per-item lengths used by ``b_mas`` to crop the mel dimension.

    Returns:
        torch.Tensor with the same shape as ``attn`` holding the hard alignment,
        placed on the same device as ``attn``.
    """
    with torch.no_grad():
        attn_cpu = attn.detach().cpu().numpy()
        attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(), out_lens.cpu().numpy(), width=1)
    # ``attn.device`` works for CPU and CUDA tensors alike; ``get_device()`` yields -1
    # for CPU tensors, which is not a valid target for ``.to()``.
    return torch.from_numpy(attn_out).to(attn.device)
|
|
|
|
|
|
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        lengths (torch.Tensor): 1-D tensor of sequence lengths.
        max_len: Width of the mask. Any falsy value (``None`` or ``0``) means
            "use the largest entry of ``lengths``".

    Returns:
        torch.Tensor of dtype bool and shape ``len(lengths) x max_len`` where entry
        ``[b, t]`` is True when ``t < lengths[b]``.
    """
    width = max_len if max_len else torch.max(lengths).item()
    positions = torch.arange(0, width, device=lengths.device, dtype=torch.long)
    # Broadcasting compares every position against every length.
    return (positions < lengths.unsqueeze(1)).bool()
|
|
|
|
|
|
@jit(nopython=True)
def mas(attn_map, width=1):
    """Monotonic alignment search over a single soft attention map.

    A dynamic program accumulates log-probabilities row by row; position ``j`` in a row
    may be reached from positions ``max(0, j - width) .. j`` of the previous row. The
    best path is then traced back from the last column of the last row.

    Args:
        attn_map: 2-D array, mel x text, of attention probabilities.
        width: how many columns to the left a step may come from.

    Returns:
        Array shaped like ``attn_map`` with ones on the selected path and zeros elsewhere.
    """
    # assumes mel x text
    n_mel = attn_map.shape[0]
    n_text = attn_map.shape[1]

    hard = np.zeros_like(attn_map)
    log_attn = np.log(attn_map)
    # In the first row only column 0 is allowed, so every path starts at the first text position.
    log_attn[0, 1:] = -np.inf

    score = np.zeros_like(log_attn)
    score[0, :] = log_attn[0, :]
    back = np.zeros_like(log_attn, dtype=np.int64)

    for row in range(1, n_mel):
        for col in range(n_text):  # for each text dim
            candidates = np.arange(max(0, col - width), col + 1)
            cand_scores = np.array([score[row - 1, c] for c in candidates])

            best = np.argmax(cand_scores)
            score[row, col] = log_attn[row, col] + cand_scores[best]
            back[row, col] = candidates[best]

    # now backtrack
    col = n_text - 1
    for row in range(n_mel - 1, -1, -1):
        hard[row, col] = 1
        col = back[row, col]
    hard[0, col] = 1

    # Every mel row and every text column must be touched by the path.
    assert hard.sum(0).all()
    assert hard.sum(1).all()

    return hard
|
|
|
|
|
|
@jit(nopython=True)
def mas_width1(attn_map):
    """mas with hardcoded width=1

    Each position may be reached only from the same column or the column immediately
    to its left in the previous row; ties go to the left column.

    Args:
        attn_map: 2-D array, mel x text, of attention probabilities.

    Returns:
        Array shaped like ``attn_map`` with ones on the selected path and zeros elsewhere.
    """
    # assumes mel x text
    n_mel = attn_map.shape[0]
    n_text = attn_map.shape[1]

    hard = np.zeros_like(attn_map)
    log_attn = np.log(attn_map)
    # In the first row only column 0 is allowed.
    log_attn[0, 1:] = -np.inf

    score = np.zeros_like(log_attn)
    score[0, :] = log_attn[0, :]
    back = np.zeros_like(log_attn, dtype=np.int64)

    for row in range(1, n_mel):
        for col in range(n_text):  # for each text dim
            source = col
            if col >= 1 and score[row - 1, col - 1] >= score[row - 1, col]:
                source = col - 1

            score[row, col] = log_attn[row, col] + score[row - 1, source]
            back[row, col] = source

    # now backtrack
    col = n_text - 1
    for row in range(n_mel - 1, -1, -1):
        hard[row, col] = 1
        col = back[row, col]
    hard[0, col] = 1
    return hard
|
|
|
|
|
|
@jit(nopython=True, parallel=True)
def b_mas(b_attn_map, in_lens, out_lens, width=1):
    """Run ``mas_width1`` over every item of a batch inside a ``prange`` loop.

    Args:
        b_attn_map: 4-D array indexed as [batch, 0, mel, text].
        in_lens: per-item lengths used to crop the last (text) dimension.
        out_lens: per-item lengths used to crop the mel dimension.
        width: must be 1; only the width-1 search is implemented here.

    Returns:
        Array shaped like ``b_attn_map``; each item's cropped region holds its hard
        alignment and everything outside it stays zero.
    """
    assert width == 1
    hard_batch = np.zeros_like(b_attn_map)

    for item in prange(b_attn_map.shape[0]):
        n_mel = out_lens[item]
        n_text = in_lens[item]
        hard_batch[item, 0, :n_mel, :n_text] = mas_width1(b_attn_map[item, 0, :n_mel, :n_text])
    return hard_batch
|
|
|
|
|
|
def griffin_lim(magnitudes, n_iters=50, n_fft=1024):
    """
    Griffin-Lim algorithm to convert magnitude spectrograms to audio signals

    Starts from a random phase, then repeatedly re-estimates the phase from the STFT
    of the current signal while keeping ``magnitudes`` fixed.

    Args:
        magnitudes: magnitude spectrogram array.
        n_iters: number of phase re-estimation rounds.
        n_fft: FFT size passed to ``librosa.stft`` during re-estimation.

    Returns:
        The reconstructed signal, or ``np.array([0])`` when the very first
        reconstruction contains non-finite values.
    """

    def _synthesize(phase):
        # Combine the fixed magnitudes with the given phase and invert.
        return librosa.istft(magnitudes * phase)

    random_phase = np.exp(2j * np.pi * np.random.rand(*magnitudes.shape))
    signal = _synthesize(random_phase)
    if not np.isfinite(signal).all():
        logging.warning("audio was not finite, skipping audio saving")
        return np.array([0])

    for _ in range(n_iters):
        _, estimated_phase = librosa.magphase(librosa.stft(signal, n_fft=n_fft))
        signal = _synthesize(estimated_phase)
    return signal
|
|
|
|
|
|
@rank_zero_only
def log_audio_to_tb(
    swriter,
    spect,
    name,
    step,
    griffin_lim_mag_scale=1024,
    griffin_lim_power=1.2,
    sr=22050,
    n_fft=1024,
    n_mels=80,
    fmax=8000,
):
    """Vocode a log-mel spectrogram with Griffin-Lim and log the audio to TensorBoard.

    Args:
        swriter: writer object exposing ``add_audio``.
        spect: tensor holding a log-mel spectrogram; it is exponentiated and its transpose
            is multiplied with the mel filterbank (presumably n_mels x frames - confirm with callers).
        name: tag under which the audio is logged.
        step: global step passed to the writer.
        griffin_lim_mag_scale: factor applied to the approximated linear magnitudes.
        griffin_lim_power: exponent applied to the magnitudes before Griffin-Lim.
        sr: sample rate used for the filterbank and for the logged audio.
        n_fft: FFT size used for the filterbank and for Griffin-Lim.
        n_mels: number of mel bands of the filterbank.
        fmax: highest frequency of the filterbank.
    """
    filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
    log_mel = spect.detach().cpu().numpy().T
    mel = np.exp(log_mel)
    magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
    # Forward n_fft so the STFT inside Griffin-Lim matches the filterbank's frequency bins.
    audio = griffin_lim(magnitude.T ** griffin_lim_power, n_fft=n_fft)
    # griffin_lim returns an all-zero array when reconstruction fails; only normalise
    # when there is a non-zero peak so no 0/0 NaNs are written to the log.
    peak = np.abs(audio).max()
    if peak > 0:
        audio = audio / peak
    swriter.add_audio(name, audio, step, sample_rate=sr)
|
|
|
|
|
|
@rank_zero_only
def tacotron2_log_to_tb_func(
    swriter,
    tensors,
    step,
    tag="train",
    log_images=False,
    log_images_freq=1,
    add_audio=True,
    griffin_lim_mag_scale=1024,
    griffin_lim_power=1.2,
    sr=22050,
    n_fft=1024,
    n_mels=80,
    fmax=8000,
):
    """Log Tacotron 2 alignment, spectrogram and gate plots (and optionally audio) to TensorBoard.

    Only the first item of each batch tensor is visualised.

    Args:
        swriter: writer object exposing ``add_image`` and ``add_audio``.
        tensors: 6-tuple unpacked as ``(_, spec_target, mel_postnet, gate, gate_target, alignments)``.
        step: global step passed to the writer.
        tag: prefix for every logged item.
        log_images: enables logging; nothing is logged when False.
        log_images_freq: logging happens only on steps divisible by this value.
        add_audio: also log Griffin-Lim audio for the predicted and target spectrograms.
        griffin_lim_mag_scale: factor applied to the approximated linear magnitudes.
        griffin_lim_power: exponent applied to the magnitudes before Griffin-Lim.
        sr: sample rate used for the filterbank and for the logged audio.
        n_fft: FFT size used for the filterbank and for Griffin-Lim.
        n_mels: number of mel bands of the filterbank.
        fmax: highest frequency of the filterbank.
    """
    _, spec_target, mel_postnet, gate, gate_target, alignments = tensors
    if log_images and step % log_images_freq == 0:
        swriter.add_image(
            f"{tag}_alignment", plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), step, dataformats="HWC",
        )
        swriter.add_image(
            f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), step, dataformats="HWC",
        )
        swriter.add_image(
            f"{tag}_mel_predicted",
            plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()),
            step,
            dataformats="HWC",
        )
        swriter.add_image(
            f"{tag}_gate",
            plot_gate_outputs_to_numpy(gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),),
            step,
            dataformats="HWC",
        )
        # NOTE(review): audio logging is kept under the image-logging gate; the original
        # indentation was lost in this copy of the file - confirm against upstream.
        if add_audio:
            filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
            for suffix, spec in (("predicted", mel_postnet), ("target", spec_target)):
                log_mel = spec[0].data.cpu().numpy().T
                mel = np.exp(log_mel)
                magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
                # Forward n_fft so the STFT inside Griffin-Lim matches the filterbank's bins.
                audio = griffin_lim(magnitude.T ** griffin_lim_power, n_fft=n_fft)
                # griffin_lim returns an all-zero array on failure; skip normalisation
                # in that case instead of dividing by zero.
                peak = np.abs(audio).max()
                if peak > 0:
                    audio = audio / peak
                swriter.add_audio(f"audio/{tag}_{suffix}", audio, step, sample_rate=sr)
|
|
|
|
|
|
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix as a heat map and return the image as an array.

    Args:
        alignment: 2-D array shown with ``imshow`` (origin at the bottom).
        info: optional text appended below the x-axis label.

    Returns:
        The rendered figure as produced by ``save_figure_to_numpy``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    heatmap = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(heatmap, ax=ax)

    extra = '' if info is None else '\n\n' + info
    plt.xlabel('Decoder timestep' + extra)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    # The canvas must be drawn before its pixel buffer can be read.
    fig.canvas.draw()
    rendered = save_figure_to_numpy(fig)
    plt.close()
    return rendered
|
|
|
|
|
|
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram as an image and return it as an array.

    Args:
        spectrogram: 2-D array; it is cast to float32 before plotting.

    Returns:
        The rendered figure as produced by ``save_figure_to_numpy``.
    """
    values = spectrogram.astype(np.float32)
    fig, ax = plt.subplots(figsize=(12, 3))
    image = ax.imshow(values, aspect="auto", origin="lower", interpolation='none')
    plt.colorbar(image, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    # The canvas must be drawn before its pixel buffer can be read.
    fig.canvas.draw()
    rendered = save_figure_to_numpy(fig)
    plt.close()
    return rendered
|
|
|
|
|
|
def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    """Scatter-plot target and predicted gate values per frame and return the image.

    Args:
        gate_targets: sequence of target gate values (drawn in green).
        gate_outputs: sequence of predicted gate values (drawn in red).

    Returns:
        The rendered figure as produced by ``save_figure_to_numpy``.
    """
    fig, ax = plt.subplots(figsize=(12, 3))

    target_frames = range(len(gate_targets))
    ax.scatter(target_frames, gate_targets, alpha=0.5, color='green', marker='+', s=1, label='target')

    predicted_frames = range(len(gate_outputs))
    ax.scatter(predicted_frames, gate_outputs, alpha=0.5, color='red', marker='.', s=1, label='predicted')

    plt.xlabel("Frames (Green target, Red predicted)")
    plt.ylabel("Gate State")
    plt.tight_layout()

    # The canvas must be drawn before its pixel buffer can be read.
    fig.canvas.draw()
    rendered = save_figure_to_numpy(fig)
    plt.close()
    return rendered
|
|
|
|
|
|
def save_figure_to_numpy(fig):
    """Copy the pixels of an already-drawn figure canvas into a NumPy array.

    Args:
        fig: figure-like object whose ``canvas`` has been drawn by the caller.

    Returns:
        A writable ``uint8`` array of shape ``height x width x 3`` (RGB).
    """
    canvas = fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        width, height = canvas.get_width_height()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring; the copy
        # keeps the result writable, as the array returned by fromstring was.
        pixels = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8).copy()
        return pixels.reshape((height, width, 3))

    # Canvases without ``tostring_rgb``: read the RGBA buffer and drop the alpha channel.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[..., :3].copy()
|
|
|
|
|
|
@rank_zero_only
def waveglow_log_to_tb_func(
    swriter, tensors, step, tag="train", n_fft=1024, hop_length=256, window="hann", mel_fb=None,
):
    """Log the target mel spectrogram and, optionally, the mel of the predicted audio.

    Only the first item of the batch is visualised.

    Args:
        swriter: writer object exposing ``add_image``.
        tensors: 4-tuple unpacked as ``(_, audio_pred, spec_target, mel_length)``.
        step: global step passed to the writer.
        tag: prefix for the logged image names.
        n_fft: FFT size for the STFT of the predicted audio.
        hop_length: hop length for that STFT.
        window: window name for that STFT.
        mel_fb: optional mel filterbank tensor; the predicted spectrogram is logged
            only when it is provided.
    """
    _, audio_pred, spec_target, mel_length = tensors
    n_frames = mel_length[0]

    target_np = spec_target[0].data.cpu().numpy()[:, :n_frames]
    swriter.add_image(
        f"{tag}_mel_target", plot_spectrogram_to_numpy(target_np), step, dataformats="HWC",
    )

    if mel_fb is None:
        return

    # Replace NaNs in the predicted waveform before taking its STFT.
    waveform = np.nan_to_num(audio_pred[0].cpu().detach().numpy())
    spectrum = librosa.core.stft(waveform, n_fft=n_fft, hop_length=hop_length, window=window)
    mag, _ = librosa.core.magphase(spectrum)
    mel_pred = np.matmul(mel_fb.cpu().numpy(), mag).squeeze()
    # Clip before the log to avoid -inf for silent bins.
    log_mel_pred = np.log(np.clip(mel_pred, a_min=1e-5, a_max=None))
    swriter.add_image(
        f"{tag}_mel_predicted", plot_spectrogram_to_numpy(log_mel_pred[:, :n_frames]), step, dataformats="HWC",
    )
|
|
|
|
|
|
def remove(conv_list):
    """Strip weight normalization from every module in ``conv_list``.

    Args:
        conv_list: iterable of modules that currently have weight normalization applied.

    Returns:
        A new ``torch.nn.ModuleList`` holding the same modules, in order, after
        ``torch.nn.utils.remove_weight_norm`` has been applied to each.
    """
    stripped = [torch.nn.utils.remove_weight_norm(conv) for conv in conv_list]
    return torch.nn.ModuleList(stripped)
|
|
|
|
|
|
def eval_tts_scores(
    y_clean: ndarray, y_est: ndarray, T_ys: Sequence[int] = (0,), sampling_rate=22050
) -> Dict[str, float]:
    """
    Compute STOI and PESQ for an estimated signal against its clean reference.

    The inputs may be a batch, but only the first item is scored.

    Args:
        y_clean: real audio
        y_est: estimated audio
        T_ys: per-item number of leading samples to score. The default ``(0,)`` is a
            sentinel meaning "use the full length of every item".
        sampling_rate: The used Sampling rate (passed to STOI).

    Returns:
        A dictionary mapping scoring systems (string) to numerical scores.
        1st entry: 'STOI'
        2nd entry: 'PESQ'
    """
    if y_clean.ndim == 1:
        y_clean = y_clean[np.newaxis, ...]
        y_est = y_est[np.newaxis, ...]
    # Normalise to a tuple so the sentinel is also recognised when passed as a list
    # or an array (an array comparison would not yield a single boolean).
    if tuple(T_ys) == (0,):
        T_ys = (y_clean.shape[1],) * y_clean.shape[0]

    clean = y_clean[0, : T_ys[0]]
    estimated = y_est[0, : T_ys[0]]
    stoi_score = stoi(clean, estimated, sampling_rate, extended=False)
    # fs is fixed at 16,000 because the pesq library does not currently support a flexible fs.
    pesq_score = pesq(16000, np.asarray(clean), estimated, 'wb')

    return {'STOI': stoi_score, 'PESQ': pesq_score}
|