[FastPitch/PyT] updated checkpoints, multispeaker and text processing

This commit is contained in:
kkudrynski 2020-10-30 15:03:27 +01:00
parent 03c5a9fd9b
commit bec82593f5
36 changed files with 1470 additions and 814 deletions

View file

@ -1,8 +1,15 @@
*.swp
*.swo
*.pyc
__pycache__
scripts_joc/
runs*/
LJSpeech-1.1/
output*
scripts_joc/
tests/
*.pyc
__pycache__
.idea/
.DS_Store
*.swp
*.swo
*.swn

View file

@ -488,11 +488,11 @@ The `scripts/train.sh` script is configured for 8x GPU with at least 16GB of mem
```
In a single accumulated step, there are `batch_size x gradient_accumulation_steps x GPUs = 256` examples being processed in parallel. With a smaller number of GPUs, increase `--gradient_accumulation_steps` to keep this relation satisfied, e.g., through env variables:
```bash
NUM_GPUS=4 GRAD_ACCUMULATION=2 bash scripts/train.sh
```
With automatic mixed precision (AMP), a larger batch size fits in 16GB of memory:
```bash
NUM_GPUS=4 GRAD_ACCUMULATION=1 BS=64 AMP=true bash scripts/train.sh
```
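For a quick sanity check of that relation, a minimal sketch (the per-GPU default batch size of 32 is an assumption, not read from the script):
```python
# Hypothetical helper: the effective batch size should stay at 256.
def effective_batch_size(batch_size, grad_accumulation, num_gpus):
    return batch_size * grad_accumulation * num_gpus

assert effective_batch_size(32, 2, 4) == 256  # NUM_GPUS=4 GRAD_ACCUMULATION=2, assumed default BS=32
assert effective_batch_size(64, 1, 4) == 256  # NUM_GPUS=4 GRAD_ACCUMULATION=1 BS=64 AMP=true
```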
### Inference process
@ -545,18 +545,18 @@ To benchmark the training performance on a specific batch size, run:
* NVIDIA DGX A100 (8x A100 40GB)
```bash
AMP=true NUM_GPUS=1 BS=128 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
AMP=true NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
NUM_GPUS=1 BS=128 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
```
* NVIDIA DGX-1 (8x V100 16GB)
```bash
AMP=true NUM_GPUS=1 BS=64 GRAD_ACCUMULATION=4 EPOCHS=10 bash scripts/train.sh
AMP=true NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
NUM_GPUS=1 BS=32 GRAD_ACCUMULATION=8 EPOCHS=10 bash scripts/train.sh
NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
```
Each of these scripts runs for 10 epochs and for each epoch measures the
@ -569,12 +569,12 @@ To benchmark the inference performance on a specific batch size, run:
* For FP16
```bash
AMP=true BS_SEQUENCE="1 4 8" REPEATS=100 bash scripts/inference_benchmark.sh
```
* For FP32 or TF32
```bash
BS_SEQUENCE="1 4 8" REPEATS=100 bash scripts/inference_benchmark.sh
```
The output log files will contain performance numbers for the FastPitch model
@ -726,6 +726,10 @@ The input utterance has 128 characters, synthesized audio has 8.05 s.
### Changelog
October 2020
- Added multispeaker capabilities
- Updated text processing module
June 2020
- Updated performance tables to include A100 results

View file

@ -1,74 +1,3 @@
""" from https://github.com/keithito/tacotron """ from .cmudict import CMUDict
import re
from common.text import cleaners
from common.text.symbols import symbols
cmudict = CMUDict()
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _symbols_to_sequence(symbols):
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s is not '_' and s is not '~'

View file

@ -0,0 +1,58 @@
import re
_no_period_re = re.compile(r'(No[.])(?=[ ]?[0-9])')
_percent_re = re.compile(r'([ ]?[%])')
_half_re = re.compile('([0-9]½)|(½)')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('ms', 'miss'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
('sen', 'senator'),
]]
def _expand_no_period(m):
word = m.group(0)
if word[0] == 'N':
return 'Number'
return 'number'
def _expand_percent(m):
return ' percent'
def _expand_half(m):
word = m.group(1)
if word is None:
return 'half'
return word[0] + ' and a half'
def normalize_abbreviations(text):
text = re.sub(_no_period_re, _expand_no_period, text)
text = re.sub(_percent_re, _expand_percent, text)
text = re.sub(_half_re, _expand_half, text)
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
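For illustration, a hedged sketch of calling the new module (the import path `common.text.abbreviations` is assumed from the package layout; the output shown is approximate):
```python
from common.text.abbreviations import normalize_abbreviations  # path assumed

print(normalize_abbreviations("Dr. Smith paid 10% at No. 7"))
# -> roughly: "doctor Smith paid 10 percent at Number 7"
```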

View file

@ -0,0 +1,67 @@
import re
from . import cmudict
_letter_to_arpabet = {
'A': 'EY1',
'B': 'B IY1',
'C': 'S IY1',
'D': 'D IY1',
'E': 'IY1',
'F': 'EH1 F',
'G': 'JH IY1',
'H': 'EY1 CH',
'I': 'AY1',
'J': 'JH EY1',
'K': 'K EY1',
'L': 'EH1 L',
'M': 'EH1 M',
'N': 'EH1 N',
'O': 'OW1',
'P': 'P IY1',
'Q': 'K Y UW1',
'R': 'AA1 R',
'S': 'EH1 S',
'T': 'T IY1',
'U': 'Y UW1',
'V': 'V IY1',
'X': 'EH1 K S',
'Y': 'W AY1',
'W': 'D AH1 B AH0 L Y UW0',
'Z': 'Z IY1',
's': 'Z'
}
# must ignore roman numerals
# _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
_acronym_re = re.compile(r'([A-Z][A-Z]+)s?')
def _expand_acronyms(m, add_spaces=True):
acronym = m.group(0)
# remove dots if they exist
acronym = re.sub('\.', '', acronym)
acronym = "".join(acronym.split())
arpabet = cmudict.lookup(acronym)
if arpabet is None:
acronym = list(acronym)
arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
# temporary fix
if arpabet[-1] == '{Z}' and len(arpabet) > 1:
arpabet[-2] = arpabet[-2][:-1] + ' ' + arpabet[-1][1:]
del arpabet[-1]
arpabet = ' '.join(arpabet)
elif len(arpabet) == 1:
arpabet = "{" + arpabet[0] + "}"
else:
arpabet = acronym
return arpabet
def normalize_acronyms(text):
text = re.sub(_acronym_re, _expand_acronyms, text)
return text

View file

@ -1,90 +1,92 @@
""" from https://github.com/keithito/tacotron """ """ adapted from https://github.com/keithito/tacotron """
''' '''
Cleaners are transformations that run over the input text at both training and eval time. Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use: hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text 1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode) the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data). the symbols in symbols.py to match your data).
''' '''
import re import re
from unidecode import unidecode from unidecode import unidecode
from .numbers import normalize_numbers from .numerical import normalize_numbers
from .acronyms import normalize_acronyms
from .datestime import normalize_datestime
from .letters_and_numbers import normalize_letters_and_numbers
from .abbreviations import normalize_abbreviations
# Regular expression matching whitespace: # Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+') _whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
    return normalize_abbreviations(text)


def expand_numbers(text):
    return normalize_numbers(text)
def expand_acronyms(text):
return normalize_acronyms(text)
def expand_datestime(text):
return normalize_datestime(text)
def expand_letters_and_numbers(text):
return normalize_letters_and_numbers(text)
def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)
def separate_acronyms(text):
text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
return text
def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    '''Basic pipeline that collapses whitespace without transliteration.'''
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
def english_cleaners_post_chars(word):
return word
def english_cleaners(text):
    '''Pipeline for English text, with number and abbreviation expansion.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
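The docstring above describes selecting cleaners by name; a small hedged usage sketch (assuming this file is importable as `common.text.cleaners`, matching imports elsewhere in this commit):
```python
from common.text import cleaners  # path assumed

# english_cleaners transliterates to ASCII, lowercases, and expands numbers/abbreviations.
print(cleaners.english_cleaners("Dr. Jones has 16 cats."))
# -> roughly: "doctor jones has sixteen cats."
```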

View file

@ -18,7 +18,18 @@ _valid_symbol_set = set(valid_symbols)
class CMUDict:
    '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
    def __init__(self, file_or_path=None, heteronyms_path=None, keep_ambiguous=True):
        if file_or_path is None:
            self._entries = {}
        else:
            self.initialize(file_or_path, keep_ambiguous)

        if heteronyms_path is None:
            self.heteronyms = []
        else:
            self.heteronyms = set(lines_to_list(heteronyms_path))

    def initialize(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding='latin-1') as f:
                entries = _parse_cmudict(f)
@ -28,17 +39,18 @@ class CMUDict:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        if len(self._entries) == 0:
            raise ValueError("CMUDict not initialized")
        return len(self._entries)

    def lookup(self, word):
        '''Returns list of ARPAbet pronunciations of the given word.'''
        if len(self._entries) == 0:
            raise ValueError("CMUDict not initialized")
        return self._entries.get(word.upper())


_alt_re = re.compile(r'\([0-9]+\)')

View file

@ -0,0 +1,22 @@
import re
_ampm_re = re.compile(
r'([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)')
def _expand_ampm(m):
matches = list(m.groups(0))
txt = matches[0]
txt = txt if int(matches[1]) == 0 else txt + ' ' + matches[1]
if matches[2][0].lower() == 'a':
txt += ' a.m.'
elif matches[2][0].lower() == 'p':
txt += ' p.m.'
return txt
def normalize_datestime(text):
text = re.sub(_ampm_re, _expand_ampm, text)
#text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
return text
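A small hedged example of the a.m./p.m. normalization above (import path assumed):
```python
from common.text.datestime import normalize_datestime  # path assumed

print(normalize_datestime("Boarding starts at 7 AM and ends at 10:35 PM"))
# -> "Boarding starts at 7 a.m. and ends at 10 35 p.m."
```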

View file

@ -0,0 +1,90 @@
import re
_letters_and_numbers_re = re.compile(
r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE)
_hardware_re = re.compile(
'([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)', re.IGNORECASE)
_hardware_key = {'tb': 'terabyte',
'gb': 'gigabyte',
'mb': 'megabyte',
'kb': 'kilobyte',
'ghz': 'gigahertz',
'mhz': 'megahertz',
'khz': 'kilohertz',
'hz': 'hertz',
'mm': 'millimeter',
'cm': 'centimeter',
'km': 'kilometer'}
_dimension_re = re.compile(
r'\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b')
_dimension_key = {'m': 'meter',
'in': 'inch',
'inch': 'inch'}
def _expand_letters_and_numbers(m):
text = re.split(r'(\d+)', m.group(0))
# remove trailing space
if text[-1] == '':
text = text[:-1]
elif text[0] == '':
text = text[1:]
# if not like 1920s, or AK47's , 20th, 1st, 2nd, 3rd, etc...
if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
text[-2] = text[-2] + text[-1]
text = text[:-1]
# for combining digits 2 by 2
new_text = []
for i in range(len(text)):
string = text[i]
if string.isdigit() and len(string) < 5:
# heuristics
if len(string) > 2 and string[-2] == '0':
if string[-1] == '0':
string = [string]
else:
string = [string[:-3], string[-2], string[-1]]
elif len(string) % 2 == 0:
string = [string[i:i+2] for i in range(0, len(string), 2)]
elif len(string) > 2:
string = [string[0]] + [string[i:i+2] for i in range(1, len(string), 2)]
new_text.extend(string)
else:
new_text.append(string)
text = new_text
text = " ".join(text)
return text
def _expand_hardware(m):
quantity, measure = m.groups(0)
measure = _hardware_key[measure.lower()]
if measure[-1] != 'z' and float(quantity.replace(',', '')) > 1:
return "{} {}s".format(quantity, measure)
return "{} {}".format(quantity, measure)
def _expand_dimension(m):
text = "".join([x for x in m.groups(0) if x != 0])
text = text.replace(' x ', ' by ')
text = text.replace('x', ' by ')
if text.endswith(tuple(_dimension_key.keys())):
if text[-2].isdigit():
text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
elif text[-3].isdigit():
text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
return text
def normalize_letters_and_numbers(text):
text = re.sub(_hardware_re, _expand_hardware, text)
text = re.sub(_dimension_re, _expand_dimension, text)
text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
return text
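A hedged usage sketch of the expansions above (import path assumed; the output depends only on the regexes in this module and is approximate):
```python
from common.text.letters_and_numbers import normalize_letters_and_numbers  # path assumed

print(normalize_letters_and_numbers("a 16GB card running at 1.5GHz"))
# -> roughly: "a 16 gigabytes card running at 1.5 gigahertz"
```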

View file

@ -1,71 +0,0 @@
""" from https://github.com/keithito/tacotron """
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text

View file

@ -0,0 +1,153 @@
""" adapted from https://github.com/keithito/tacotron """
import inflect
import re
_magnitudes = ['trillion', 'billion', 'million', 'thousand', 'hundred', 'm', 'b', 't']
_magnitudes_key = {'m': 'million', 'b': 'billion', 't': 'trillion'}
_measurements = '(f|c|k|d|m)'
_measurements_key = {'f': 'fahrenheit',
'c': 'celsius',
'k': 'thousand',
'm': 'meters'}
_currency_key = {'$': 'dollar', '£': 'pound', '€': 'euro', '₩': 'won'}
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_currency_re = re.compile(r'([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?'.format("|".join(_magnitudes)), re.IGNORECASE)
_measurement_re = re.compile(r'([0-9\.\,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
_roman_re = re.compile(r'\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b') # avoid I
_multiply_re = re.compile(r'(\b[0-9]+)(x)([0-9]+)')
_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_currency(m):
currency = _currency_key[m.group(1)]
quantity = m.group(2)
magnitude = m.group(3)
# remove commas from quantity to be able to convert to numerical
quantity = quantity.replace(',', '')
# check for million, billion, etc...
if magnitude is not None and magnitude.lower() in _magnitudes:
if len(magnitude) == 1:
magnitude = _magnitudes_key[magnitude.lower()]
return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency+'s')
parts = quantity.split('.')
if len(parts) > 2:
return quantity + " " + currency + "s" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = currency if dollars == 1 else currency+'s'
cent_unit = 'cent' if cents == 1 else 'cents'
return "{} {}, {} {}".format(
_expand_hundreds(dollars), dollar_unit,
_inflect.number_to_words(cents), cent_unit)
elif dollars:
dollar_unit = currency if dollars == 1 else currency+'s'
return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
else:
return 'zero' + ' ' + currency + 's'
def _expand_hundreds(text):
number = float(text)
if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
return _inflect.number_to_words(int(number / 100)) + " hundred"
else:
return _inflect.number_to_words(text)
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_measurement(m):
_, number, measurement = re.split('(\d+(?:\.\d+)?)', m.group(0))
number = _inflect.number_to_words(number)
measurement = "".join(measurement.split())
measurement = _measurements_key[measurement.lower()]
return "{} {}".format(number, measurement)
def _expand_range(m):
return ' to '
def _expand_multiply(m):
left = m.group(1)
right = m.group(3)
return "{} by {}".format(left, right)
def _expand_roman(m):
# from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
roman_numerals = {'I':1, 'V':5, 'X':10, 'L':50, 'C':100, 'D':500, 'M':1000}
result = 0
num = m.group(0)
for i, c in enumerate(num):
if (i+1) == len(num) or roman_numerals[c] >= roman_numerals[num[i+1]]:
result += roman_numerals[c]
else:
result -= roman_numerals[c]
return str(result)
def _expand_number(m):
_, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
number = int(number)
if number > 1000 < 10000 and (number % 100 == 0) and (number % 1000 != 0):
text = _inflect.number_to_words(number // 100) + " hundred"
elif number > 1000 and number < 3000:
if number == 2000:
text = 'two thousand'
elif number > 2000 and number < 2010:
text = 'two thousand ' + _inflect.number_to_words(number % 100)
elif number % 100 == 0:
text = _inflect.number_to_words(number // 100) + ' hundred'
else:
number = _inflect.number_to_words(number, andword='', zero='oh', group=2).replace(', ', ' ')
number = re.sub(r'-', ' ', number)
text = number
else:
number = _inflect.number_to_words(number, andword='and')
number = re.sub(r'-', ' ', number)
number = re.sub(r',', '', number)
text = number
if suffix in ("'s", "s"):
if text[-1] == 'y':
text = text[:-1] + 'ies'
else:
text = text + suffix
return text
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_currency_re, _expand_currency, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
# text = re.sub(_range_re, _expand_range, text)
# text = re.sub(_measurement_re, _expand_measurement, text)
text = re.sub(_roman_re, _expand_roman, text)
text = re.sub(_multiply_re, _expand_multiply, text)
text = re.sub(_number_re, _expand_number, text)
return text
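A hedged usage sketch of the new number normalization (import path assumed; the exact wording follows `inflect` and is approximate):
```python
from common.text.numerical import normalize_numbers  # path assumed

print(normalize_numbers("She paid $250 for 17 tickets."))
# -> roughly: "She paid two hundred and fifty dollars for seventeen tickets."
```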

View file

@ -4,16 +4,41 @@
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''

from .cmudict import valid_symbols
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in cmudict.valid_symbols] _arpabet = ['@' + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet def get_symbols(symbol_set='english_basic'):
pad_idx = 0 if symbol_set == 'english_basic':
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
elif symbol_set == 'english_basic_lowercase':
_pad = '_'
_punctuation = '!\'"(),.:;? '
_special = '-'
_letters = 'abcdefghijklmnopqrstuvwxyz'
symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
elif symbol_set == 'english_expanded':
_punctuation = '!\'",.:;? '
_math = '#%&*+-/[]()'
_special = '_@©°½—₩€$'
_accented = 'áçéêëñöøćž'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
else:
raise Exception("{} symbol set does not exist".format(symbol_set))
return symbols
def get_pad_idx(symbol_set='english_basic'):
if symbol_set in {'english_basic', 'english_basic_lowercase'}:
return 0
else:
raise Exception("{} symbol set not used yet".format(symbol_set))
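A short sketch of how the new symbol sets can be consumed (import path assumed; the count of 148 lines up with the `--n-symbols` default set later in this commit):
```python
from common.text.symbols import get_symbols, get_pad_idx  # path assumed

symbols = get_symbols('english_basic')
print(len(symbols))                   # 148 for the default English set
print(get_pad_idx('english_basic'))   # 0 -- '_' is the padding symbol
```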

View file

@ -0,0 +1,175 @@
""" adapted from https://github.com/keithito/tacotron """
import re
import numpy as np
from . import cleaners
from .symbols import get_symbols
from .cmudict import CMUDict
from .numerical import _currency_re, _expand_currency
#########
# REGEX #
#########
# Regular expression matching text enclosed in curly braces for encoding
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
# Regular expression matching words and not words
_words_re = re.compile(r"([a-zA-ZÀ-ž]+['][a-zA-ZÀ-ž]{1,2}|[a-zA-ZÀ-ž]+)|([{][^}]+[}]|[^a-zA-ZÀ-ž{}]+)")
# Regular expression separating words enclosed in curly braces for cleaning
_arpa_re = re.compile(r'{[^}]+}|\S+')
def lines_to_list(filename):
with open(filename, encoding='utf-8') as f:
lines = f.readlines()
lines = [l.rstrip() for l in lines]
return lines
class TextProcessing(object):
def __init__(self, symbol_set, cleaner_names, p_arpabet=0.0,
handle_arpabet='word', handle_arpabet_ambiguous='ignore',
expand_currency=True):
self.symbols = get_symbols(symbol_set)
self.cleaner_names = cleaner_names
# Mappings from symbol to numeric ID and vice versa:
self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
self.expand_currency = expand_currency
# cmudict
self.p_arpabet = p_arpabet
self.handle_arpabet = handle_arpabet
self.handle_arpabet_ambiguous = handle_arpabet_ambiguous
def text_to_sequence(self, text):
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += self.symbols_to_sequence(text)
break
sequence += self.symbols_to_sequence(m.group(1))
sequence += self.arpabet_to_sequence(m.group(2))
text = m.group(3)
return sequence
def sequence_to_text(self, sequence):
result = ''
for symbol_id in sequence:
if symbol_id in self.id_to_symbol:
s = self.id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def clean_text(self, text):
for name in self.cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def symbols_to_sequence(self, symbols):
return [self.symbol_to_id[s] for s in symbols if s in self.symbol_to_id]
def arpabet_to_sequence(self, text):
return self.symbols_to_sequence(['@' + s for s in text.split()])
def get_arpabet(self, word):
arpabet_suffix = ''
if word.lower() in cmudict.heteronyms:
return word
if len(word) > 2 and word.endswith("'s"):
arpabet = cmudict.lookup(word)
if arpabet is None:
arpabet = self.get_arpabet(word[:-2])
arpabet_suffix = ' Z'
elif len(word) > 1 and word.endswith("s"):
arpabet = cmudict.lookup(word)
if arpabet is None:
arpabet = self.get_arpabet(word[:-1])
arpabet_suffix = ' Z'
else:
arpabet = cmudict.lookup(word)
if arpabet is None:
return word
elif arpabet[0] == '{':
arpabet = [arpabet[1:-1]]
if len(arpabet) > 1:
if self.handle_arpabet_ambiguous == 'first':
arpabet = arpabet[0]
elif self.handle_arpabet_ambiguous == 'random':
arpabet = np.random.choice(arpabet)
elif self.handle_arpabet_ambiguous == 'ignore':
return word
else:
arpabet = arpabet[0]
arpabet = "{" + arpabet + arpabet_suffix + "}"
return arpabet
# def get_characters(self, word):
# for name in self.cleaner_names:
# cleaner = getattr(cleaners, f'{name}_post_chars')
# if not cleaner:
# raise Exception('Unknown cleaner: %s' % name)
# word = cleaner(word)
# return word
def encode_text(self, text, return_all=False):
if self.expand_currency:
text = re.sub(_currency_re, _expand_currency, text)
text_clean = [self.clean_text(split) if split[0] != '{' else split
for split in _arpa_re.findall(text)]
text_clean = ' '.join(text_clean)
text = text_clean
text_arpabet = ''
if self.p_arpabet > 0:
if self.handle_arpabet == 'sentence':
if np.random.uniform() < self.p_arpabet:
words = _words_re.findall(text)
text_arpabet = [
self.get_arpabet(word[0])
if (word[0] != '') else word[1]
for word in words]
text_arpabet = ''.join(text_arpabet)
text = text_arpabet
elif self.handle_arpabet == 'word':
words = _words_re.findall(text)
text_arpabet = [
word[1] if word[0] == '' else (
self.get_arpabet(word[0])
if np.random.uniform() < self.p_arpabet
else word[0])
for word in words]
text_arpabet = ''.join(text_arpabet)
text = text_arpabet
elif self.handle_arpabet != '':
raise Exception("{} handle_arpabet is not supported".format(
self.handle_arpabet))
text_encoded = self.text_to_sequence(text)
if return_all:
return text_encoded, text_clean, text_arpabet
return text_encoded
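A minimal sketch of driving the new `TextProcessing` front end (import path assumed; with `p_arpabet=0.0` no CMU dictionary is required):
```python
from common.text.text_processing import TextProcessing  # path assumed

tp = TextProcessing('english_basic', ['english_cleaners'], p_arpabet=0.0)
ids = tp.encode_text("Hello, world!")
print(ids)                       # list of symbol IDs
print(tp.sequence_to_text(ids))  # "hello, world!" after cleaning
```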

View file

@ -25,7 +25,6 @@
#
# *****************************************************************************

import os
from pathlib import Path
from typing import Optional
@ -48,14 +47,20 @@ def load_wav_to_torch(full_path):
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
    def split_line(root, line):
        parts = line.strip().split(split)
        if has_speakers:
            paths, non_paths = parts[:-2], parts[-2:]
        else:
            paths, non_paths = parts[:-1], parts[-1:]
        return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
    fpaths_and_text = []
    for fname in fnames.split(','):
        with open(fname, encoding='utf-8') as f:
            fpaths_and_text += [split_line(dataset_path, line) for line in f]
    return fpaths_and_text


def stats_filename(dataset_path, filelist_path, feature_name):

View file

@ -40,6 +40,7 @@ from torch.utils.data import DataLoader
from common import utils
from inference import load_and_setup_model
from tacotron2.data_function import TextMelLoader, TextMelCollate, batch_to_gpu
from common.text.text_processing import TextProcessing


def parse_args(parser):
@ -59,6 +60,8 @@ def parse_args(parser):
    parser.add_argument('--text-cleaners', nargs='*',
                        default=['english_cleaners'], type=str,
                        help='Type of text cleaners for input text')
    parser.add_argument('--symbol-set', type=str, default='english_basic',
                        help='Define symbol set for input text')
    parser.add_argument('--max-wav-value', default=32768.0, type=float,
                        help='Maximum audiowave value')
    parser.add_argument('--sampling-rate', default=22050, type=int,
@ -98,6 +101,7 @@ def parse_args(parser):
class FilenamedLoader(TextMelLoader):
    def __init__(self, filenames, *args, **kwargs):
        super(FilenamedLoader, self).__init__(*args, **kwargs)
        self.tp = TextProcessing(args[-1].symbol_set, args[-1].text_cleaners)
        self.filenames = filenames

    def __getitem__(self, index):
@ -211,6 +215,8 @@ def main():
    filenames = [Path(l.split('|')[0]).stem
                 for l in open(args.wav_text_filelist, 'r')]

    # Compatibility with Tacotron2 Data loader
    args.n_speakers = 1
    dataset = FilenamedLoader(filenames, args.dataset_path, args.wav_text_filelist,
                              args, load_mel_from_disk=False)
    # TextMelCollate supports only n_frames_per_step=1

View file

@ -27,8 +27,6 @@
import argparse

from common.text import symbols


def parse_fastpitch_args(parent, add_help=False):
    """
@ -42,11 +40,12 @@ def parse_fastpitch_args(parent, add_help=False):
                    help='Number of bins in mel-spectrograms')
    io.add_argument('--max-seq-len', default=2048, type=int,
                    help='')
global symbols
len_symbols = len(symbols)
    symbols = parser.add_argument_group('symbols parameters')
    symbols.add_argument('--n-symbols', default=148, type=int,
                         help='Number of symbols in dictionary')
    symbols.add_argument('--padding-idx', default=0, type=int,
                         help='Index of padding symbol in dictionary')
    symbols.add_argument('--symbols-embedding-dim', default=384, type=int,
                         help='Input embedding dimension')
@ -102,11 +101,18 @@ def parse_fastpitch_args(parent, add_help=False):
    pitch_pred = parser.add_argument_group('pitch predictor parameters')
    pitch_pred.add_argument('--pitch-predictor-kernel-size', default=3, type=int,
                            help='Pitch predictor conv-1D kernel size')
    pitch_pred.add_argument('--pitch-predictor-filter-size', default=256, type=int,
                            help='Pitch predictor conv-1D filter size')
    pitch_pred.add_argument('--p-pitch-predictor-dropout', default=0.1, type=float,
                            help='Pitch probability for pitch predictor')
    pitch_pred.add_argument('--pitch-predictor-n-layers', default=2, type=int,
                            help='Number of conv-1D layers')
cond = parser.add_argument_group('conditioning parameters')
cond.add_argument('--pitch-embedding-kernel-size', default=3, type=int,
help='Pitch embedding conv-1D kernel size')
cond.add_argument('--speaker-emb-weight', type=float, default=1.0,
help='Scale speaker embedding')
    return parser

View file

@ -31,6 +31,7 @@ import torch
from common.utils import to_gpu
from tacotron2.data_function import TextMelLoader
from common.text.text_processing import TextProcessing


class TextMelAliLoader(TextMelLoader):
@ -38,18 +39,27 @@ class TextMelAliLoader(TextMelLoader):
""" """
def __init__(self, *args): def __init__(self, *args):
super(TextMelAliLoader, self).__init__(*args) super(TextMelAliLoader, self).__init__(*args)
if len(self.audiopaths_and_text[0]) != 4: self.tp = TextProcessing(args[-1].symbol_set, args[-1].text_cleaners)
raise ValueError('Expected four columns in audiopaths file') self.n_speakers = args[-1].n_speakers
if len(self.audiopaths_and_text[0]) != 4 + (args[-1].n_speakers > 1):
raise ValueError('Expected four columns in audiopaths file for single speaker model. \n'
'For multispeaker model, the filelist format is '
'<mel>|<dur>|<pitch>|<text>|<speaker_id>')
    def __getitem__(self, index):
        # separate filename and text
        if self.n_speakers > 1:
            audiopath, durpath, pitchpath, text, speaker = self.audiopaths_and_text[index]
            speaker = int(speaker)
        else:
            audiopath, durpath, pitchpath, text = self.audiopaths_and_text[index]
            speaker = None
        len_text = len(text)
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
        dur = torch.load(durpath)
        pitch = torch.load(pitchpath)
        return (text, mel, len_text, dur, pitch, speaker)


class TextMelAliCollate():
@ -107,16 +117,24 @@ class TextMelAliCollate():
            pitch = batch[ids_sorted_decreasing[i]][4]
            pitch_padded[i, :pitch.shape[0]] = pitch

        if batch[0][5] is not None:
            speaker = torch.zeros_like(input_lengths)
            for i in range(len(ids_sorted_decreasing)):
                speaker[i] = batch[ids_sorted_decreasing[i]][5]
        else:
            speaker = None

        # count number of items - characters in text
        len_x = [x[2] for x in batch]
        len_x = torch.Tensor(len_x)

        return (text_padded, input_lengths, mel_padded, output_lengths,
                len_x, dur_padded, dur_lens, pitch_padded, speaker)


def batch_to_gpu(batch):
    text_padded, input_lengths, mel_padded, output_lengths, \
        len_x, dur_padded, dur_lens, pitch_padded, speaker = batch
    text_padded = to_gpu(text_padded).long()
    input_lengths = to_gpu(input_lengths).long()
    mel_padded = to_gpu(mel_padded).float()
@ -124,9 +142,11 @@ def batch_to_gpu(batch):
    dur_padded = to_gpu(dur_padded).long()
    dur_lens = to_gpu(dur_lens).long()
    pitch_padded = to_gpu(pitch_padded).float()
    if speaker is not None:
        speaker = to_gpu(speaker).long()

    # Alignments act as both inputs and targets - pass shallow copies
    x = [text_padded, input_lengths, mel_padded, output_lengths,
         dur_padded, dur_lens, pitch_padded, speaker]
    y = [mel_padded, dur_padded, dur_lens, pitch_padded]
    len_x = torch.sum(output_lengths)
    return (x, y, len_x)

View file

@ -73,7 +73,7 @@ class FastPitchLoss(nn.Module):
                'mel_loss': mel_loss.clone().detach(),
                'duration_predictor_loss': dur_pred_loss.clone().detach(),
                'pitch_loss': pitch_loss.clone().detach(),
                'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
                              / dur_mask.sum()).detach(),
            }

        assert meta_agg in ('sum', 'mean')

View file

@ -1,210 +1,236 @@
# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of the NVIDIA CORPORATION nor the
#    names of its contributors may be used to endorse or promote products
#    derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
from torch import nn as nn
from torch.nn.utils.rnn import pad_sequence

from common.layers import ConvReLUNorm
from common.utils import mask_from_lens
from fastpitch.transformer import FFTransformer
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """If target=None, then predicted durations are applied"""
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)

    enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
                            for o, r in zip(enc_out, reps)],
                           batch_first=True)
    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens


class TemporalPredictor(nn.Module):
    """Predicts a single float per each temporal location"""

    def __init__(self, input_size, filter_size, kernel_size, dropout,
                 n_layers=2):
        super(TemporalPredictor, self).__init__()

        self.layers = nn.Sequential(*[
            ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
                         kernel_size=kernel_size, dropout=dropout)
            for i in range(n_layers)]
        )
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        out = enc_out * enc_out_mask
        out = self.layers(out.transpose(1, 2)).transpose(1, 2)
        out = self.fc(out) * enc_out_mask
        return out.squeeze(-1)


class FastPitch(nn.Module):
def __init__(self, n_mel_channels, max_seq_len, n_symbols, def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads, symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
in_fft_d_head, in_fft_d_head,
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size, in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
in_fft_output_size, in_fft_output_size,
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb, p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
out_fft_n_layers, out_fft_n_heads, out_fft_d_head, out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size, out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
out_fft_output_size, out_fft_output_size,
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb, p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
dur_predictor_kernel_size, dur_predictor_filter_size, dur_predictor_kernel_size, dur_predictor_filter_size,
p_dur_predictor_dropout, dur_predictor_n_layers, p_dur_predictor_dropout, dur_predictor_n_layers,
pitch_predictor_kernel_size, pitch_predictor_filter_size, pitch_predictor_kernel_size, pitch_predictor_filter_size,
p_pitch_predictor_dropout, pitch_predictor_n_layers): p_pitch_predictor_dropout, pitch_predictor_n_layers,
super(FastPitch, self).__init__() pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
del max_seq_len # unused super(FastPitch, self).__init__()
del n_symbols del max_seq_len # unused
self.encoder = FFTransformer( self.encoder = FFTransformer(
n_layer=in_fft_n_layers, n_head=in_fft_n_heads, n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
d_model=symbols_embedding_dim, d_model=symbols_embedding_dim,
d_head=in_fft_d_head, d_head=in_fft_d_head,
d_inner=in_fft_conv1d_filter_size, d_inner=in_fft_conv1d_filter_size,
kernel_size=in_fft_conv1d_kernel_size, kernel_size=in_fft_conv1d_kernel_size,
dropout=p_in_fft_dropout, dropout=p_in_fft_dropout,
dropatt=p_in_fft_dropatt, dropatt=p_in_fft_dropatt,
dropemb=p_in_fft_dropemb, dropemb=p_in_fft_dropemb,
d_embed=symbols_embedding_dim, embed_input=True,
embed_input=True) d_embed=symbols_embedding_dim,
n_embed=n_symbols,
self.duration_predictor = TemporalPredictor( padding_idx=padding_idx)
in_fft_output_size,
filter_size=dur_predictor_filter_size, if n_speakers > 1:
kernel_size=dur_predictor_kernel_size, self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers else:
) self.speaker_emb = None
self.speaker_emb_weight = speaker_emb_weight
self.decoder = FFTransformer(
n_layer=out_fft_n_layers, n_head=out_fft_n_heads, self.duration_predictor = TemporalPredictor(
d_model=symbols_embedding_dim, in_fft_output_size,
d_head=out_fft_d_head, filter_size=dur_predictor_filter_size,
d_inner=out_fft_conv1d_filter_size, kernel_size=dur_predictor_kernel_size,
kernel_size=out_fft_conv1d_kernel_size, dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
dropout=p_out_fft_dropout, )
dropatt=p_out_fft_dropatt,
dropemb=p_out_fft_dropemb, self.decoder = FFTransformer(
d_embed=symbols_embedding_dim, n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
embed_input=False) d_model=symbols_embedding_dim,
d_head=out_fft_d_head,
self.pitch_predictor = TemporalPredictor( d_inner=out_fft_conv1d_filter_size,
in_fft_output_size, kernel_size=out_fft_conv1d_kernel_size,
filter_size=pitch_predictor_filter_size, dropout=p_out_fft_dropout,
kernel_size=pitch_predictor_kernel_size, dropatt=p_out_fft_dropatt,
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers dropemb=p_out_fft_dropemb,
) embed_input=False,
self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3, d_embed=symbols_embedding_dim
padding=1) )
# Store values precomputed for training data within the model self.pitch_predictor = TemporalPredictor(
self.register_buffer('pitch_mean', torch.zeros(1)) in_fft_output_size,
self.register_buffer('pitch_std', torch.zeros(1)) filter_size=pitch_predictor_filter_size,
kernel_size=pitch_predictor_kernel_size,
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True) dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
)
def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
pace=1.0, max_duration=75): self.pitch_emb = nn.Conv1d(
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs 1, symbols_embedding_dim,
mel_max_len = mel_tgt.size(2) kernel_size=pitch_embedding_kernel_size,
            padding=int((pitch_embedding_kernel_size - 1) / 2))

        # Store values precomputed for training data within the model
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)

    def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
                pace=1.0, max_duration=75):
        inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
        mel_max_len = mel_tgt.size(2)

        # Calculate speaker embedding
        if self.speaker_emb is None:
            spk_emb = 0
        else:
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)
            spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Predict pitch
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if use_gt_pitch and pitch_tgt is not None:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        len_regulated, dec_lens = regulate_len(
            dur_tgt if use_gt_durations else dur_pred,
            enc_out, pace, mel_max_len)

        # Output FFT
        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred

    def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
              pitch_transform=None, max_duration=75, speaker=0):
        del input_lens  # unused

        if self.speaker_emb is None:
            spk_emb = 0
        else:
            speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)
            spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Pitch over chars
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if pitch_transform is not None:
            if self.pitch_std[0] == 0.0:
                # XXX LJSpeech-1.1 defaults
                mean, std = 218.14, 67.24
            else:
                mean, std = self.pitch_mean[0], self.pitch_std[0]
            pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)

        if pitch_tgt is None:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)

        enc_out = enc_out + pitch_emb

        len_regulated, dec_lens = regulate_len(
            dur_pred if dur_tgt is None else dur_tgt,
            enc_out, pace, mel_max_len=None)

        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
        mel_out = mel_out.permute(0, 2, 1)  # For inference.py
        return mel_out, dec_lens, dur_pred, pitch_pred
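The `speaker` id is looked up once per utterance and the resulting vector is broadcast-added to every encoder position through the new `conditioning` argument. A shape-only sketch of that addition (the sizes and the weight of 1.0 are illustrative, not values fixed by this commit):

```python
import torch
import torch.nn as nn

n_symbols, n_speakers, d_model = 90, 4, 384        # illustrative sizes
word_emb = nn.Embedding(n_symbols, d_model, padding_idx=0)
speaker_emb = nn.Embedding(n_speakers, d_model)

inputs = torch.randint(1, n_symbols, (2, 20))      # batch of 2 encoded texts
speaker = torch.full((2,), 3, dtype=torch.long)    # speaker id 3 for both

spk_emb = speaker_emb(speaker).unsqueeze(1)        # (2, 1, 384)
spk_emb.mul_(1.0)                                  # speaker_emb_weight
enc_in = word_emb(inputs) + spk_emb                # broadcasts over all 20 positions
print(enc_in.shape)                                # torch.Size([2, 20, 384])
```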

View file

@@ -1,218 +1,246 @@
# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

from typing import List, Optional

import torch
from torch import nn as nn
from torch.nn.utils.rnn import pad_sequence

from common.layers import ConvReLUNorm
from fastpitch.transformer_jit import FFTransformer


def regulate_len(durations, enc_out, pace: float = 1.0,
                 mel_max_len: Optional[int] = None):
    """If target=None, then predicted durations are applied"""
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)

    max_len = dec_lens.max()
    bsz, _, hid = enc_out.size()

    reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
    pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
                          device=enc_out.device)

    enc_rep = torch.cat([enc_out, pad_vec], dim=1)
    enc_rep = torch.repeat_interleave(
        enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
    ).view(bsz, -1, hid)

    # enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
    #                         for o, r in zip(enc_out, reps)],
    #                        batch_first=True)
    if mel_max_len is not None:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens


class TemporalPredictor(nn.Module):
    """Predicts a single float per each temporal location"""

    def __init__(self, input_size, filter_size, kernel_size, dropout,
                 n_layers=2):
        super(TemporalPredictor, self).__init__()

        self.layers = nn.Sequential(*[
            ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
                         kernel_size=kernel_size, dropout=dropout)
            for i in range(n_layers)]
        )
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        out = enc_out * enc_out_mask
        out = self.layers(out.transpose(1, 2)).transpose(1, 2)
        out = self.fc(out) * enc_out_mask
        return out.squeeze(-1)


class FastPitch(nn.Module):
    def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
                 symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
                 in_fft_d_head,
                 in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
                 in_fft_output_size,
                 p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
                 out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
                 out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
                 out_fft_output_size,
                 p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
                 dur_predictor_kernel_size, dur_predictor_filter_size,
                 p_dur_predictor_dropout, dur_predictor_n_layers,
                 pitch_predictor_kernel_size, pitch_predictor_filter_size,
                 p_pitch_predictor_dropout, pitch_predictor_n_layers,
                 pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
        super(FastPitch, self).__init__()
        del max_seq_len  # unused

        self.encoder = FFTransformer(
            n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=in_fft_d_head,
            d_inner=in_fft_conv1d_filter_size,
            kernel_size=in_fft_conv1d_kernel_size,
            dropout=p_in_fft_dropout,
            dropatt=p_in_fft_dropatt,
            dropemb=p_in_fft_dropemb,
            embed_input=True,
            d_embed=symbols_embedding_dim,
            n_embed=n_symbols,
            padding_idx=padding_idx)

        if n_speakers > 1:
            self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
        else:
            self.speaker_emb = None
        self.speaker_emb_weight = speaker_emb_weight

        self.duration_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=dur_predictor_filter_size,
            kernel_size=dur_predictor_kernel_size,
            dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
        )

        self.decoder = FFTransformer(
            n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=out_fft_d_head,
            d_inner=out_fft_conv1d_filter_size,
            kernel_size=out_fft_conv1d_kernel_size,
            dropout=p_out_fft_dropout,
            dropatt=p_out_fft_dropatt,
            dropemb=p_out_fft_dropemb,
            embed_input=False,
            d_embed=symbols_embedding_dim
        )

        self.pitch_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=pitch_predictor_filter_size,
            kernel_size=pitch_predictor_kernel_size,
            dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
        )

        self.pitch_emb = nn.Conv1d(
            1, symbols_embedding_dim,
            kernel_size=pitch_embedding_kernel_size,
            padding=int((pitch_embedding_kernel_size - 1) / 2))

        # Store values precomputed for training data within the model
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)

    def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
                use_gt_pitch: bool = True, pace: float = 1.0,
                max_duration: int = 75):
        inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
        mel_max_len = mel_tgt.size(2)

        # Calculate speaker embedding
        if self.speaker_emb is None:
            spk_emb = 0
        else:
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)
            spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Predict pitch
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if use_gt_pitch and pitch_tgt is not None:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        len_regulated, dec_lens = regulate_len(
            dur_tgt if use_gt_durations else dur_pred,
            enc_out, pace, mel_max_len)

        # Output FFT
        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred

    def infer(self, inputs, input_lens, pace: float = 1.0,
              dur_tgt: Optional[torch.Tensor] = None,
              pitch_tgt: Optional[torch.Tensor] = None,
              max_duration: float = 75,
              speaker: int = 0):
        del input_lens  # unused

        if self.speaker_emb is None:
            spk_emb = None
        else:
            speaker = torch.ones(inputs.size(0), dtype=torch.long,
                                 device=inputs.device).fill_(speaker)
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)
            spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Pitch over chars
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if pitch_tgt is None:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)

        enc_out = enc_out + pitch_emb

        len_regulated, dec_lens = regulate_len(
            dur_pred if dur_tgt is None else dur_tgt,
            enc_out, pace, mel_max_len=None)

        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
        mel_out = mel_out.permute(0, 2, 1)  # For inference.py
        return mel_out, dec_lens, dur_pred, pitch_pred
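`regulate_len` above stretches the encoder output onto the decoder's time axis by repeating each symbol's frame according to its duration. A toy, self-contained illustration of that repeat (made-up values; the real function additionally pads ragged batches and clamps to `mel_max_len`):

```python
import torch

enc_out = torch.tensor([[[1.0], [2.0], [3.0]]])   # 3 encoder frames, hidden size 1
durations = torch.tensor([[2, 1, 3]])             # decoder frames per symbol

reps = torch.round(durations.float() / 1.0).long()                      # pace = 1.0
regulated = torch.repeat_interleave(enc_out.view(-1, 1), reps.view(-1), dim=0)
print(regulated.view(-1).tolist())                # [1.0, 1.0, 2.0, 3.0, 3.0, 3.0]
print(reps.sum(dim=1))                            # tensor([6]) -> decoder length
```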

View file

@@ -17,7 +17,6 @@ import torch.nn as nn
 import torch.nn.functional as F

 from common.utils import mask_from_lens
-from common.text.symbols import pad_idx, symbols


 class PositionalEmbedding(nn.Module):
@@ -248,16 +247,17 @@ class TransformerLayer(nn.Module):
 class FFTransformer(nn.Module):
     def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
-                 dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
-                 pre_lnorm=False):
+                 dropout, dropatt, dropemb=0.0, embed_input=True,
+                 n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
         super(FFTransformer, self).__init__()
         self.d_model = d_model
         self.n_head = n_head
         self.d_head = d_head
+        self.padding_idx = padding_idx

         if embed_input:
-            self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
-                                         padding_idx=pad_idx)
+            self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
+                                         padding_idx=self.padding_idx)
         else:
             self.word_emb = None
@@ -272,18 +272,18 @@ class FFTransformer(nn.Module):
                       dropatt=dropatt, pre_lnorm=pre_lnorm)
             )

-    def forward(self, dec_inp, seq_lens=None):
+    def forward(self, dec_inp, seq_lens=None, conditioning=0):
         if self.word_emb is None:
             inp = dec_inp
             mask = mask_from_lens(seq_lens).unsqueeze(2)
         else:
             inp = self.word_emb(dec_inp)
             # [bsz x L x 1]
-            mask = (dec_inp != pad_idx).unsqueeze(2)
+            mask = (dec_inp != self.padding_idx).unsqueeze(2)

         pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
         pos_emb = self.pos_emb(pos_seq) * mask
-        out = self.drop(inp + pos_emb)
+        out = self.drop(inp + pos_emb + conditioning)

         for layer in self.layers:
             out = layer(out, mask=mask)
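With the symbol table no longer imported globally, callers construct `FFTransformer` with the vocabulary size and padding index passed in explicitly. A hedged instantiation sketch (the hyperparameter values are illustrative, not mandated by this diff):

```python
from fastpitch.transformer import FFTransformer

encoder = FFTransformer(
    n_layer=6, n_head=1, d_model=384, d_head=64, d_inner=1536,
    kernel_size=3, dropout=0.1, dropatt=0.1, dropemb=0.0,
    embed_input=True, d_embed=384,
    n_embed=148,      # e.g. len(get_symbols(symbol_set)) in models.py
    padding_idx=0)    # e.g. get_pad_idx(symbol_set)
```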

View file

@@ -19,7 +19,6 @@ import torch.nn as nn
 import torch.nn.functional as F

 from common.utils import mask_from_lens
-from common.text.symbols import pad_idx, symbols


 class NoOp(nn.Module):
@@ -255,20 +254,20 @@ class TransformerLayer(nn.Module):
 class FFTransformer(nn.Module):
-    pad_idx = 0  # XXX
-
     def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
-                 dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
-                 pre_lnorm=False):
+                 dropout, dropatt, dropemb=0.0, embed_input=True,
+                 n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
         super(FFTransformer, self).__init__()
         self.d_model = d_model
         self.n_head = n_head
         self.d_head = d_head
+        self.padding_idx = padding_idx
+        self.n_embed = n_embed
         self.embed_input = embed_input

         if embed_input:
-            self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
-                                         padding_idx=FFTransformer.pad_idx)
+            self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
+                                         padding_idx=self.padding_idx)
         else:
             self.word_emb = NoOp()
@@ -283,19 +282,23 @@ class FFTransformer(nn.Module):
                       dropatt=dropatt, pre_lnorm=pre_lnorm)
             )

-    def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None):
-        if self.embed_input:
-            inp = self.word_emb(dec_inp)
-            # [bsz x L x 1]
-            # mask = (dec_inp != FFTransformer.pad_idx).unsqueeze(2)
-            mask = (dec_inp != 0).unsqueeze(2)
-        else:
-            inp = dec_inp
-            assert seq_lens is not None
-            mask = mask_from_lens(seq_lens).unsqueeze(2)
+    def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
+                conditioning: Optional[torch.Tensor] = None):
+        if not self.embed_input:
+            inp = dec_inp
+            assert seq_lens is not None
+            mask = mask_from_lens(seq_lens).unsqueeze(2)
+        else:
+            inp = self.word_emb(dec_inp)
+            # [bsz x L x 1]
+            mask = (dec_inp != self.padding_idx).unsqueeze(2)

         pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
         pos_emb = self.pos_emb(pos_seq) * mask
-        out = self.drop(inp + pos_emb)
+        if conditioning is not None:
+            out = self.drop(inp + pos_emb + conditioning)
+        else:
+            out = self.drop(inp + pos_emb)

         for layer in self.layers:
             out = layer(out, mask=mask)

View file

@@ -45,7 +45,7 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 from common import utils
 from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
                                 unique_log_fpath)
-from common.text import text_to_sequence
+from common.text.text_processing import TextProcessing
 from pitch_transform import pitch_transform_custom
 from waveglow import model as glow
 from waveglow.denoiser import Denoiser
@@ -92,9 +92,11 @@ def parse_args(parser):
                         help='Use EMA averaged model (if saved in checkpoints)')
     parser.add_argument('--dataset-path', type=str,
                         help='Path to dataset (for loading extra data fields)')
+    parser.add_argument('--speaker', type=int, default=0,
+                        help='Speaker ID for a multi-speaker model')

     transform = parser.add_argument_group('transform')
-    transform.add_argument('--fade-out', type=int, default=5,
+    transform.add_argument('--fade-out', type=int, default=10,
                            help='Number of fadeout frames at the end')
     transform.add_argument('--pace', type=float, default=1.0,
                            help='Adjust the pace of speech')
@@ -108,6 +110,18 @@ def parse_args(parser):
                            help='Raise/lower the pitch by <hz>')
     transform.add_argument('--pitch-transform-custom', action='store_true',
                            help='Apply the transform from pitch_transform.py')

+    text_processing = parser.add_argument_group('Text processing parameters')
+    text_processing.add_argument('--text-cleaners', nargs='*',
+                                 default=['english_cleaners'], type=str,
+                                 help='Type of text cleaners for input text')
+    text_processing.add_argument('--symbol-set', type=str, default='english_basic',
+                                 help='Define symbol set for input text')
+
+    cond = parser.add_argument_group('conditioning on additional attributes')
+    cond.add_argument('--n-speakers', type=int, default=1,
+                      help='Number of speakers in the model.')
+
     return parser
@@ -138,7 +152,7 @@ def load_and_setup_model(model_name, parser, checkpoint, amp, device,
         if any(key.startswith('module.') for key in sd):
             sd = {k.replace('module.', ''): v for k,v in sd.items()}
-        status += ' ' + str(model.load_state_dict(sd, strict=False))
+        status += ' ' + str(model.load_state_dict(sd, strict=True))
     else:
         model = checkpoint_data['model']
     print(f'Loaded {model_name}{status}')
@@ -162,10 +176,13 @@ def load_fields(fpath):
     return {c:f for c, f in zip(columns, fields)}


-def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
-                           load_mels=False, load_pitch=False):
-    fields['text'] = [torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
-                      for t in fields['text']]
+def prepare_input_sequence(fields, device, symbol_set, text_cleaners,
+                           batch_size=128, dataset=None, load_mels=False,
+                           load_pitch=False):
+    tp = TextProcessing(symbol_set, text_cleaners)
+
+    fields['text'] = [torch.LongTensor(tp.encode_text(text))
+                      for text in fields['text']]
     order = np.argsort([-t.size(0) for t in fields['text']])

     fields['text'] = [fields['text'][i] for i in order]
@@ -206,7 +223,6 @@ def prepare_input_sequence(fields, device, batch_size=128, dataset=None,

 def build_pitch_transformation(args):
     if args.pitch_transform_custom:
         def custom_(pitch, pitch_lens, mean, std):
             return (pitch_transform_custom(pitch * std + mean, pitch_lens)
@@ -262,7 +278,7 @@ def main():
                       StdOutBackend(Verbosity.VERBOSE,
                                     metric_format=stdout_metric_format)])
     init_inference_metadata()
-    [DLLogger.log("PARAMETER", {k:v}) for k,v in vars(args).items()]
+    [DLLogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]

     device = torch.device('cuda' if args.cuda else 'cpu')
@@ -293,8 +309,8 @@ def main():
     fields = load_fields(args.input)
     batches = prepare_input_sequence(
-        fields, device, args.batch_size, args.dataset_path,
-        load_mels=(generator is None))
+        fields, device, args.symbol_set, args.text_cleaners, args.batch_size,
+        args.dataset_path, load_mels=(generator is None))

     if args.include_warmup:
         # Use real data rather than synthetic - FastPitch predicts len
@@ -311,11 +327,13 @@ def main():
     waveglow_measures = MeasureTime()

     gen_kw = {'pace': args.pace,
+              'speaker': args.speaker,
               'pitch_tgt': None,
               'pitch_transform': build_pitch_transformation(args)}

     if args.torchscript:
         gen_kw.pop('pitch_transform')
+        print('NOTE: Pitch transforms are disabled with TorchScript')

     all_utterances = 0
     all_samples = 0
@@ -323,11 +341,10 @@ def main():
     all_frames = 0

     reps = args.repeats
-    log_enabled = True  # reps == 1
+    log_enabled = reps == 1
     log = lambda s, d: DLLogger.log(step=s, data=d) if log_enabled else None

-    # for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
-    for rep in range(reps):
+    for rep in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
         for b in batches:
             if generator is None:
                 log(rep, {'Synthesizing from ground truth mels'})
@@ -348,7 +365,7 @@ def main():
                 audios = waveglow(mel, sigma=args.sigma_infer)
                 audios = denoiser(audios.float(),
                                   strength=args.denoising_strength
                                   ).squeeze(1)

             all_utterances += len(audios)
             all_samples += sum(audio.size(0) for audio in audios)
@@ -367,7 +384,7 @@ def main():
                 fade_w = torch.linspace(1.0, 0.0, fade_len)
                 audio[-fade_len:] *= fade_w.to(audio.device)
-            audio = audio/torch.max(torch.abs(audio))
+            audio = audio / torch.max(torch.abs(audio))
             fname = b['output'][i] if 'output' in b else f'audio_{i}.wav'
             audio_path = Path(args.output, fname)
             write(audio_path, args.sampling_rate, audio.cpu().numpy())
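With the new flags in place, a synthesis run can pick the text pipeline and a speaker id on the command line. An illustrative invocation only; checkpoint paths follow the example scripts and should be adjusted to your setup:

```bash
python inference.py --cuda \
    -i phrases/devset10.tsv \
    -o output/audio_devset10 \
    --fastpitch output/FastPitch_checkpoint_1500.pt \
    --waveglow pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt \
    --wn-channels 256 \
    --symbol-set english_basic \
    --text-cleaners english_cleaners \
    --n-speakers 1 \
    --speaker 0
```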

View file

@@ -37,6 +37,7 @@ from fastpitch.model import FastPitch as _FastPitch
 from fastpitch.model_jit import FastPitch as _FastPitchJIT
 from tacotron2.model import Tacotron2
 from waveglow.model import WaveGlow
+from common.text.symbols import get_symbols, get_pad_idx


 def parse_model_args(model_name, parser, add_help=False):
@@ -94,25 +95,27 @@ def get_model(model_name, model_config, device,
         model = WaveGlow(**model_config)

     elif model_name == 'FastPitch':
         if forward_is_infer:
             if jitable:
                 class FastPitch__forward_is_infer(_FastPitchJIT):
                     def forward(self, inputs, input_lengths, pace: float = 1.0,
                                 dur_tgt: Optional[torch.Tensor] = None,
-                                pitch_tgt: Optional[torch.Tensor] = None):
+                                pitch_tgt: Optional[torch.Tensor] = None,
+                                speaker: int = 0):
                         return self.infer(inputs, input_lengths, pace=pace,
-                                          dur_tgt=dur_tgt, pitch_tgt=pitch_tgt)
+                                          dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
+                                          speaker=speaker)
             else:
                 class FastPitch__forward_is_infer(_FastPitch):
                     def forward(self, inputs, input_lengths, pace: float = 1.0,
                                 dur_tgt: Optional[torch.Tensor] = None,
                                 pitch_tgt: Optional[torch.Tensor] = None,
-                                pitch_transform=None):
+                                pitch_transform=None,
+                                speaker: Optional[int] = None):
                         return self.infer(inputs, input_lengths, pace=pace,
                                           dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
-                                          pitch_transform=pitch_transform)
+                                          pitch_transform=pitch_transform,
+                                          speaker=speaker)
             model = FastPitch__forward_is_infer(**model_config)
         else:
@@ -136,7 +139,7 @@ def get_model_config(model_name, args):
         # audio
         n_mel_channels=args.n_mel_channels,
         # symbols
-        n_symbols=args.n_symbols,
+        n_symbols=len(get_symbols(args.symbol_set)),
         symbols_embedding_dim=args.symbols_embedding_dim,
         # encoder
         encoder_kernel_size=args.encoder_kernel_size,
@@ -183,7 +186,8 @@ def get_model_config(model_name, args):
         n_mel_channels=args.n_mel_channels,
         max_seq_len=args.max_seq_len,
         # symbols
-        n_symbols=args.n_symbols,
+        n_symbols=len(get_symbols(args.symbol_set)),
+        padding_idx=get_pad_idx(args.symbol_set),
         symbols_embedding_dim=args.symbols_embedding_dim,
         # input FFT
         in_fft_n_layers=args.in_fft_n_layers,
@@ -215,6 +219,11 @@ def get_model_config(model_name, args):
         pitch_predictor_filter_size=args.pitch_predictor_filter_size,
         p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
         pitch_predictor_n_layers=args.pitch_predictor_n_layers,
+        # pitch conditioning
+        pitch_embedding_kernel_size=args.pitch_embedding_kernel_size,
+        # speakers parameters
+        n_speakers=args.n_speakers,
+        speaker_emb_weight=args.speaker_emb_weight
     )
     return model_config
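A hedged sketch of how the model config now derives the symbol count from the symbol-set name instead of a `--n-symbols` flag (`english_basic` is the new default set):

```python
from common.text.symbols import get_symbols, get_pad_idx

symbol_set = 'english_basic'
n_symbols = len(get_symbols(symbol_set))   # replaces args.n_symbols
padding_idx = get_pad_idx(symbol_set)      # passed through to nn.Embedding
print(n_symbols, padding_idx)
```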

View file

@@ -6,16 +6,13 @@ DATA_DIR="LJSpeech-1.1"
 LJS_ARCH="LJSpeech-1.1.tar.bz2"
 LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"

-if [ ! -f ${LJS_ARCH} ]; then
+if [ ! -d ${DATA_DIR} ]; then
   echo "Downloading ${LJS_ARCH} ..."
   wget -q ${LJS_URL}
-fi
-
-if [ ! -d ${DATA_DIR} ]; then
   echo "Extracting ${LJS_ARCH} ..."
   tar jxvf ${LJS_ARCH}
   rm -f ${LJS_ARCH}
 fi

-bash scripts/download_tacotron2.sh
-bash scripts/download_waveglow.sh
+bash ./scripts/download_tacotron2.sh
+bash ./scripts/download_waveglow.sh

View file

@@ -2,19 +2,20 @@
 set -e

-MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
-FASTP_ZIP="nvidia_fastpitch_200518.zip"
-FASTP_CH="nvidia_fastpitch_200518.pt"
-FASTP_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"
+: ${MODEL_DIR:="pretrained_models/fastpitch"}
+MODEL_ZIP="nvidia_fastpitch_200518.zip"
+MODEL_CH="nvidia_fastpitch_200518.pt"
+MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"

-mkdir -p "$MODEL_DIR"/fastpitch
+mkdir -p "$MODEL_DIR"

-if [ ! -f "${MODEL_DIR}/fastpitch/${FASTP_ZIP}" ]; then
-  echo "Downloading ${FASTP_ZIP} ..."
-  wget -qO ${MODEL_DIR}/fastpitch/${FASTP_ZIP} ${FASTP_URL}
+if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
+  echo "Downloading ${MODEL_ZIP} ..."
+  wget -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
+       || { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
 fi

-if [ ! -f "${MODEL_DIR}/fastpitch/${FASTP_CH}" ]; then
-  echo "Extracting ${FASTP_CH} ..."
-  unzip -qo ${MODEL_DIR}/fastpitch/${FASTP_ZIP} -d ${MODEL_DIR}/fastpitch/
+if [ ! -f "${MODEL_DIR}/${MODEL_CH}" ]; then
+  echo "Extracting ${MODEL_CH} ..."
+  unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR}
 fi

View file

@@ -2,12 +2,18 @@
 set -e

-MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
-TACO_CH="nvidia_tacotron2pyt_fp32_20190427.pt"
-TACO_URL="https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/2/files/nvidia_tacotron2pyt_fp32_20190427"
+: ${MODEL_DIR:="pretrained_models/tacotron2"}
+MODEL="nvidia_tacotron2pyt_fp16.pt"
+MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_amp/versions/19.12.0/files/nvidia_tacotron2pyt_fp16.pt"

-if [ ! -f "${MODEL_DIR}/tacotron2/${TACO_CH}" ]; then
-  echo "Downloading ${TACO_CH} ..."
-  mkdir -p "$MODEL_DIR"/tacotron2
-  wget -qO ${MODEL_DIR}/tacotron2/${TACO_CH} ${TACO_URL}
+mkdir -p "$MODEL_DIR"
+
+if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
+  echo "Downloading ${MODEL} ..."
+  wget --content-disposition -qO ${MODEL_DIR}/${MODEL} ${MODEL_URL} \
+       || { echo "ERROR: Failed to download ${MODEL} from NGC"; exit 1; }
+  echo "OK"
+else
+  echo "${MODEL} already downloaded."
 fi

View file

@@ -2,19 +2,26 @@
 set -e

-MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
-WAVEG="waveglow_1076430_14000_amp"
-WAVEG_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow256pyt_fp16/versions/2/zip"
+: ${MODEL_DIR:="pretrained_models/waveglow"}
+MODEL="nvidia_waveglow256pyt_fp16"
+MODEL_ZIP="waveglow_ckpt_amp_256_20.01.0.zip"
+MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp_256/versions/20.01.0/zip"

-mkdir -p "$MODEL_DIR"/waveglow
+mkdir -p "$MODEL_DIR"

-if [ ! -f "${MODEL_DIR}/waveglow/${WAVEG}.zip" ]; then
-  echo "Downloading ${WAVEG}.zip ..."
-  wget -qO "${MODEL_DIR}/waveglow/${WAVEG}.zip" ${WAVEG_URL}
+if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
+  echo "Downloading ${MODEL_ZIP} ..."
+  wget --content-disposition -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
+       || { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
 fi

-if [ ! -f "${MODEL_DIR}/waveglow/${WAVEG}.pt" ]; then
-  echo "Extracting ${WAVEG} ..."
-  unzip -qo "${MODEL_DIR}/waveglow/${WAVEG}.zip" -d ${MODEL_DIR}/waveglow/
-  mv "${MODEL_DIR}/waveglow/${WAVEG}" "${MODEL_DIR}/waveglow/${WAVEG}.pt"
+if [ ! -f "${MODEL_DIR}/${MODEL}.pt" ]; then
+  echo "Extracting ${MODEL} ..."
+  unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
+       || { echo "ERROR: Failed to extract ${MODEL_ZIP}"; exit 1; }
+  echo "OK"
+else
+  echo "${MODEL}.pt already downloaded."
 fi

View file

@@ -1,22 +1,26 @@
 #!/bin/bash

-[ ! -n "$WAVEG_CH" ] && WAVEG_CH="pretrained_models/waveglow/waveglow_1076430_14000_amp.pt"
-[ ! -n "$FASTPITCH_CH" ] && FASTPITCH_CH="output/FastPitch_checkpoint_1500.pt"
-[ ! -n "$REPEATS" ] && REPEATS=1000
-[ ! -n "$BS_SEQ" ] && BS_SEQ="1 4 8"
-[ ! -n "$PHRASES" ] && PHRASES="phrases/benchmark_8_128.tsv"
-[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
-[ "$AMP" == "true" ] && AMP_FLAG="--amp" || AMP=false
+: ${WAVEGLOW:="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
+: ${FASTPITCH:="output/FastPitch_checkpoint_1500.pt"}
+: ${REPEATS:=1000}
+: ${BS_SEQUENCE:="1 4 8"}
+: ${PHRASES:="phrases/benchmark_8_128.tsv"}
+: ${OUTPUT_DIR:="./output/audio_$(basename ${PHRASES} .tsv)"}
+: ${AMP:=false}

-for BS in $BS_SEQ ; do
+[ "$AMP" = true ] && AMP_FLAG="--amp"
+
+mkdir -p "$OUTPUT_DIR"
+
+for BS in $BS_SEQUENCE ; do
     echo -e "\nAMP: ${AMP}, batch size: ${BS}\n"
     python inference.py --cuda --cudnn-benchmark \
         -i ${PHRASES} \
         -o ${OUTPUT_DIR} \
-        --fastpitch ${FASTPITCH_CH} \
-        --waveglow ${WAVEG_CH} \
+        --fastpitch ${FASTPITCH} \
+        --waveglow ${WAVEGLOW} \
         --wn-channels 256 \
         --include-warmup \
         --batch-size ${BS} \

View file

@@ -1,22 +1,21 @@
 #!/usr/bin/env bash

-DATA_DIR="LJSpeech-1.1"
-
-[ ! -n "$WAVEG_CH" ] && WAVEG_CH="pretrained_models/waveglow/waveglow_1076430_14000_amp.pt"
-[ ! -n "$FASTPITCH_CH" ] && FASTPITCH_CH="output/FastPitch_checkpoint_1500.pt"
-[ ! -n "$BS" ] && BS=32
-[ ! -n "$PHRASES" ] && PHRASES="phrases/devset10.tsv"
-[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
-[ "$AMP" == "true" ] && AMP_FLAG="--amp"
+: ${WAVEGLOW:="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
+: ${FASTPITCH:="output/FastPitch_checkpoint_1500.pt"}
+: ${BS:=32}
+: ${PHRASES:="phrases/devset10.tsv"}
+: ${OUTPUT_DIR:="./output/audio_$(basename ${PHRASES} .tsv)"}
+: ${AMP:=false}
+
+[ "$AMP" = true ] && AMP_FLAG="--amp"

 mkdir -p "$OUTPUT_DIR"

 python inference.py --cuda \
     -i ${PHRASES} \
     -o ${OUTPUT_DIR} \
-    --dataset-path ${DATA_DIR} \
-    --fastpitch ${FASTPITCH_CH} \
-    --waveglow ${WAVEG_CH} \
+    --fastpitch ${FASTPITCH} \
+    --waveglow ${WAVEGLOW} \
     --wn-channels 256 \
     --batch-size ${BS} \
     ${AMP_FLAG}

View file

@@ -3,7 +3,7 @@
 set -e

 DATA_DIR="LJSpeech-1.1"
-TACO_CH="pretrained_models/tacotron2/nvidia_tacotron2pyt_fp32_20190427.pt"
+TACOTRON2="pretrained_models/tacotron2/nvidia_tacotron2pyt_fp16.pt"

 for FILELIST in ljs_audio_text_train_filelist.txt \
                 ljs_audio_text_val_filelist.txt \
                 ljs_audio_text_test_filelist.txt \
@@ -16,5 +16,5 @@ for FILELIST in ljs_audio_text_train_filelist.txt \
         --extract-mels \
         --extract-durations \
         --extract-pitch-char \
-        --tacotron2-checkpoint ${TACO_CH}
+        --tacotron2-checkpoint ${TACOTRON2}
 done

View file

@@ -2,24 +2,26 @@
 export OMP_NUM_THREADS=1

-# Adjust env variables to maintain the global batch size
-#
-#   NGPU x BS x GRAD_ACC = 256.
-#
-
-[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output"
-[ ! -n "$NGPU" ] && NGPU=8
-[ ! -n "$BS" ] && BS=32
-[ ! -n "$GRAD_ACC" ] && GRAD_ACC=1
-[ ! -n "$EPOCHS" ] && EPOCHS=1500
+: ${NUM_GPUS:=8}
+: ${BS:=32}
+: ${GRAD_ACCUMULATION:=1}
+: ${OUTPUT_DIR:="./output"}
+: ${AMP:=false}
+: ${EPOCHS:=1500}

 [ "$AMP" == "true" ] && AMP_FLAG="--amp"

-GBS=$(($NGPU * $BS * $GRAD_ACC))
+# Adjust env variables to maintain the global batch size
+#
+#   NGPU x BS x GRAD_ACC = 256.
+#
+GBS=$(($NUM_GPUS * $BS * $GRAD_ACCUMULATION))
 [ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}.\n"
-echo -e "\nSetup: ${NGPU}x${BS}x${GRAD_ACC} - global batch size ${GBS}\n"
+echo -e "\nSetup: ${NUM_GPUS}x${BS}x${GRAD_ACCUMULATION} - global batch size ${GBS}\n"

 mkdir -p "$OUTPUT_DIR"

-python -m torch.distributed.launch --nproc_per_node ${NGPU} train.py \
+python -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} train.py \
     --cuda \
     -o "$OUTPUT_DIR/" \
     --log-file "$OUTPUT_DIR/nvlog.json" \
@@ -37,5 +39,5 @@ python -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} train.py \
     --dur-predictor-loss-scale 0.1 \
     --pitch-predictor-loss-scale 0.1 \
     --weight-decay 1e-6 \
-    --gradient-accumulation-steps ${GRAD_ACC} \
+    --gradient-accumulation-steps ${GRAD_ACCUMULATION} \
     ${AMP_FLAG}

View file

@@ -27,9 +27,6 @@
 import argparse

-from common.text import symbols
-

 def parse_tacotron2_args(parent, add_help=False):
     """
     Parse commandline arguments.
@@ -43,11 +40,7 @@ def parse_tacotron2_args(parent, add_help=False):
                        help='Number of bins in mel-spectrograms')

     # symbols parameters
-    global symbols
-    len_symbols = len(symbols)
     symbols = parser.add_argument_group('symbols parameters')
-    symbols.add_argument('--n-symbols', default=len_symbols, type=int,
-                         help='Number of symbols in dictionary')
     symbols.add_argument('--symbols-embedding-dim', default=512, type=int,
                          help='Input embedding dimension')

View file

@@ -33,8 +33,6 @@ import torch.utils.data

 import common.layers as layers
 from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
-from common.text import text_to_sequence


 class TextMelLoader(torch.utils.data.Dataset):
     """
@@ -43,8 +41,9 @@ class TextMelLoader(torch.utils.data.Dataset):
         3) computes mel-spectrograms from audio files.
     """
     def __init__(self, dataset_path, audiopaths_and_text, args, load_mel_from_disk=True):
-        self.audiopaths_and_text = load_filepaths_and_text(dataset_path, audiopaths_and_text)
-        self.text_cleaners = args.text_cleaners
+        self.audiopaths_and_text = load_filepaths_and_text(
+            dataset_path, audiopaths_and_text,
+            has_speakers=(args.n_speakers > 1))
         self.load_mel_from_disk = load_mel_from_disk
         if not load_mel_from_disk:
             self.max_wav_value = args.max_wav_value
@@ -74,14 +73,14 @@ class TextMelLoader(torch.utils.data.Dataset):
         return melspec

     def get_text(self, text):
-        text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
-        return text_norm
+        text_encoded = torch.IntTensor(self.tp.encode_text(text))
+        return text_encoded

     def __getitem__(self, index):
         # separate filename and text
         audiopath, text = self.audiopaths_and_text[index]
-        len_text = len(text)
         text = self.get_text(text)
+        len_text = len(text)
         mel = self.get_mel(audiopath)
         return (text, mel, len_text)
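Text normalization and encoding now go through `TextProcessing`, mirroring the `--symbol-set`/`--text-cleaners` options. A minimal sketch of the call pattern used above; the return value of `encode_text` is assumed to be a list of integer symbol ids, as it is wrapped in `torch.IntTensor` here:

```python
import torch
from common.text.text_processing import TextProcessing

tp = TextProcessing('english_basic', ['english_cleaners'])
ids = tp.encode_text("The forecast for tomorrow is sunny.")
text_encoded = torch.IntTensor(ids)   # fed to the encoder's nn.Embedding
print(text_encoded.shape)
```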

View file

@@ -108,14 +108,20 @@ def parse_args(parser):
     dataset = parser.add_argument_group('dataset parameters')
     dataset.add_argument('--training-files', type=str, required=True,
-                         help='Path to training filelist')
+                         help='Path to training filelist. Separate multiple paths with commas.')
     dataset.add_argument('--validation-files', type=str, required=True,
-                         help='Path to validation filelist')
+                         help='Path to validation filelist. Separate multiple paths with commas.')
     dataset.add_argument('--pitch-mean-std-file', type=str, default=None,
                          help='Path to pitch stats to be stored in the model')
     dataset.add_argument('--text-cleaners', nargs='*',
                          default=['english_cleaners'], type=str,
                          help='Type of text cleaners for input text')
+    dataset.add_argument('--symbol-set', type=str, default='english_basic',
+                         help='Define symbol set for input text')
+
+    cond = parser.add_argument_group('conditioning on additional attributes')
+    cond.add_argument('--n-speakers', type=int, default=1,
+                      help='Condition on speaker, value > 1 enables trainable speaker embeddings.')

     distributed = parser.add_argument_group('distributed setup')
     distributed.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
@@ -316,8 +322,6 @@ def main():
     model = models.get_model('FastPitch', model_config, device)

     # Store pitch mean/std as params to translate from Hz during inference
-    fpath = common.utils.stats_filename(
-        args.dataset_path, args.training_files, 'pitch_char')
     with open(args.pitch_mean_std_file, 'r') as f:
         stats = json.load(f)
     model.pitch_mean[0] = stats['mean']
@@ -530,6 +534,13 @@ def main():
     validate(model, None, total_iter, criterion, valset, args.batch_size,
              collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)

+    if (epoch > 0 and args.epochs_per_checkpoint > 0 and
+            (epoch % args.epochs_per_checkpoint != 0) and args.local_rank == 0):
+        checkpoint_path = os.path.join(
+            args.output, f"FastPitch_checkpoint_{epoch}.pt")
+        save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
+                        total_iter, model_config, args.amp, checkpoint_path)
+

 if __name__ == '__main__':
     main()

View file

@@ -58,6 +58,7 @@ class Invertible1x1Conv(torch.nn.Module):
         if torch.det(W) < 0:
             W[:, 0] = -1 * W[:, 0]
         W = W.view(c, c, 1)
+        W = W.contiguous()
         self.conv.weight.data = W

     def forward(self, z):
@@ -279,6 +280,49 @@ class WaveGlow(torch.nn.Module):

         return audio

+    def infer_onnx(self, spect, z, sigma=0.9):
+
+        spect = self.upsample(spect)
+        # trim conv artifacts. maybe pad spec to kernel multiple
+        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+        spect = spect[:, :, :-time_cutoff]
+
+        length_spect_group = spect.size(2)//8
+        mel_dim = 80
+        batch_size = spect.size(0)
+
+        spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
+        spect = spect.permute(0, 2, 1, 3)
+        spect = spect.contiguous()
+        spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
+        spect = spect.permute(0, 2, 1)
+        spect = spect.contiguous()
+
+        audio = z[:, :self.n_remaining_channels, :]
+        z = z[:, self.n_remaining_channels:self.n_group, :]
+        audio = sigma*audio
+
+        for k in reversed(range(self.n_flows)):
+            n_half = int(audio.size(1) // 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:(n_half+n_half), :]
+
+            output = self.WN[k]((audio_0, spect))
+            s = output[:, n_half:(n_half+n_half), :]
+            b = output[:, :n_half, :]
+
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
+            audio = self.convinv[k].infer(audio)
+
+            if k % self.n_early_every == 0 and k > 0:
+                audio = torch.cat((z[:, :self.n_early_size, :], audio), 1)
+                z = z[:, self.n_early_size:self.n_group, :]
+
+        audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
+
+        return audio
+
     @staticmethod
     def remove_weightnorm(model):
         waveglow = model