Merge pull request #731 from NVIDIA/gh/release
[FastPitch/PyT] updated checkpoints, multispeaker and text processing
This commit is contained in:
commit
475256f71d
17
PyTorch/SpeechSynthesis/FastPitch/.gitignore
vendored
17
PyTorch/SpeechSynthesis/FastPitch/.gitignore
vendored
|
@ -1,8 +1,15 @@
|
|||
*.swp
|
||||
*.swo
|
||||
*.pyc
|
||||
__pycache__
|
||||
scripts_joc/
|
||||
runs*/
|
||||
LJSpeech-1.1/
|
||||
output*
|
||||
scripts_joc/
|
||||
tests/
|
||||
|
||||
*.pyc
|
||||
__pycache__
|
||||
|
||||
.idea/
|
||||
.DS_Store
|
||||
|
||||
*.swp
|
||||
*.swo
|
||||
*.swn
|
||||
|
|
|
@ -488,11 +488,11 @@ The `scripts/train.sh` script is configured for 8x GPU with at least 16GB of mem
|
|||
```
|
||||
In a single accumulated step, there are `batch_size x gradient_accumulation_steps x GPUs = 256` examples being processed in parallel. With a smaller number of GPUs, increase `--gradient_accumulation_steps` to keep this relation satisfied, e.g., through env variables
|
||||
```bash
|
||||
NGPU=4 GRAD_ACC=2 bash scripts/train.sh
|
||||
NUM_GPUS=4 GRAD_ACCUMULATION=2 bash scripts/train.sh
|
||||
```
|
||||
With automatic mixed precision (AMP), a larger batch size fits in 16GB of memory:
|
||||
```bash
|
||||
NGPU=4 GRAD_ACC=1 BS=64 AMP=true bash scripts/train.sh
|
||||
NUM_GPUS=4 GRAD_ACCUMULATION=1 BS=64 AMP=true bash scripts/train.sh
|
||||
```
|
||||
|
||||
### Inference process
|
||||
|
@ -545,18 +545,18 @@ To benchmark the training performance on a specific batch size, run:
|
|||
|
||||
* NVIDIA DGX A100 (8x A100 40GB)
|
||||
```bash
|
||||
AMP=true NGPU=1 BS=128 GRAD_ACC=2 EPOCHS=10 bash scripts/train.sh
|
||||
AMP=true NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
||||
NGPU=1 BS=128 GRAD_ACC=2 EPOCHS=10 bash scripts/train.sh
|
||||
NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
||||
AMP=true NUM_GPUS=1 BS=128 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
|
||||
AMP=true NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||
NUM_GPUS=1 BS=128 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
|
||||
NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||
```
|
||||
|
||||
* NVIDIA DGX-1 (8x V100 16GB)
|
||||
```bash
|
||||
AMP=true NGPU=1 BS=64 GRAD_ACC=4 EPOCHS=10 bash scripts/train.sh
|
||||
AMP=true NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
||||
NGPU=1 BS=32 GRAD_ACC=8 EPOCHS=10 bash scripts/train.sh
|
||||
NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
||||
AMP=true NUM_GPUS=1 BS=64 GRAD_ACCUMULATION=4 EPOCHS=10 bash scripts/train.sh
|
||||
AMP=true NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||
NUM_GPUS=1 BS=32 GRAD_ACCUMULATION=8 EPOCHS=10 bash scripts/train.sh
|
||||
NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||
```
|
||||
|
||||
Each of these scripts runs for 10 epochs and for each epoch measures the
|
||||
|
@ -569,12 +569,12 @@ To benchmark the inference performance on a specific batch size, run:
|
|||
|
||||
* For FP16
|
||||
```bash
|
||||
AMP=true BS_SEQ=”1 4 8” REPEATS=100 bash scripts/inference_benchmark.sh
|
||||
AMP=true BS_SEQUENCE=”1 4 8” REPEATS=100 bash scripts/inference_benchmark.sh
|
||||
```
|
||||
|
||||
* For FP32 or TF32
|
||||
```bash
|
||||
BS_SEQ=”1 4 8” REPEATS=100 bash scripts/inference_benchmark.sh
|
||||
BS_SEQUENCE=”1 4 8” REPEATS=100 bash scripts/inference_benchmark.sh
|
||||
```
|
||||
|
||||
The output log files will contain performance numbers for the FastPitch model
|
||||
|
@ -726,6 +726,10 @@ The input utterance has 128 characters, synthesized audio has 8.05 s.
|
|||
|
||||
### Changelog
|
||||
|
||||
October 2020
|
||||
- Added multispeaker capabilities
|
||||
- Updated text processing module
|
||||
|
||||
June 2020
|
||||
- Updated performance tables to include A100 results
|
||||
|
||||
|
|
|
@ -1,74 +1,3 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
import re
|
||||
from common.text import cleaners
|
||||
from common.text.symbols import symbols
|
||||
from .cmudict import CMUDict
|
||||
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||
|
||||
|
||||
def text_to_sequence(text, cleaner_names):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
if not m:
|
||||
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
||||
break
|
||||
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
return sequence
|
||||
|
||||
|
||||
def sequence_to_text(sequence):
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
result += s
|
||||
return result.replace('}{', ' ')
|
||||
|
||||
|
||||
def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
|
||||
|
||||
def _symbols_to_sequence(symbols):
|
||||
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
||||
|
||||
|
||||
def _arpabet_to_sequence(text):
|
||||
return _symbols_to_sequence(['@' + s for s in text.split()])
|
||||
|
||||
|
||||
def _should_keep_symbol(s):
|
||||
return s in _symbol_to_id and s is not '_' and s is not '~'
|
||||
cmudict = CMUDict()
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import re
|
||||
|
||||
_no_period_re = re.compile(r'(No[.])(?=[ ]?[0-9])')
|
||||
_percent_re = re.compile(r'([ ]?[%])')
|
||||
_half_re = re.compile('([0-9]½)|(½)')
|
||||
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('mrs', 'misess'),
|
||||
('ms', 'miss'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
('sen', 'senator'),
|
||||
]]
|
||||
|
||||
|
||||
def _expand_no_period(m):
|
||||
word = m.group(0)
|
||||
if word[0] == 'N':
|
||||
return 'Number'
|
||||
return 'number'
|
||||
|
||||
|
||||
def _expand_percent(m):
|
||||
return ' percent'
|
||||
|
||||
|
||||
def _expand_half(m):
|
||||
word = m.group(1)
|
||||
if word is None:
|
||||
return 'half'
|
||||
return word[0] + ' and a half'
|
||||
|
||||
|
||||
def normalize_abbreviations(text):
|
||||
text = re.sub(_no_period_re, _expand_no_period, text)
|
||||
text = re.sub(_percent_re, _expand_percent, text)
|
||||
text = re.sub(_half_re, _expand_half, text)
|
||||
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
67
PyTorch/SpeechSynthesis/FastPitch/common/text/acronyms.py
Normal file
67
PyTorch/SpeechSynthesis/FastPitch/common/text/acronyms.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import re
|
||||
from . import cmudict
|
||||
|
||||
_letter_to_arpabet = {
|
||||
'A': 'EY1',
|
||||
'B': 'B IY1',
|
||||
'C': 'S IY1',
|
||||
'D': 'D IY1',
|
||||
'E': 'IY1',
|
||||
'F': 'EH1 F',
|
||||
'G': 'JH IY1',
|
||||
'H': 'EY1 CH',
|
||||
'I': 'AY1',
|
||||
'J': 'JH EY1',
|
||||
'K': 'K EY1',
|
||||
'L': 'EH1 L',
|
||||
'M': 'EH1 M',
|
||||
'N': 'EH1 N',
|
||||
'O': 'OW1',
|
||||
'P': 'P IY1',
|
||||
'Q': 'K Y UW1',
|
||||
'R': 'AA1 R',
|
||||
'S': 'EH1 S',
|
||||
'T': 'T IY1',
|
||||
'U': 'Y UW1',
|
||||
'V': 'V IY1',
|
||||
'X': 'EH1 K S',
|
||||
'Y': 'W AY1',
|
||||
'W': 'D AH1 B AH0 L Y UW0',
|
||||
'Z': 'Z IY1',
|
||||
's': 'Z'
|
||||
}
|
||||
|
||||
# must ignore roman numerals
|
||||
# _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
|
||||
_acronym_re = re.compile(r'([A-Z][A-Z]+)s?')
|
||||
|
||||
|
||||
def _expand_acronyms(m, add_spaces=True):
|
||||
acronym = m.group(0)
|
||||
|
||||
# remove dots if they exist
|
||||
acronym = re.sub('\.', '', acronym)
|
||||
|
||||
acronym = "".join(acronym.split())
|
||||
arpabet = cmudict.lookup(acronym)
|
||||
|
||||
if arpabet is None:
|
||||
acronym = list(acronym)
|
||||
arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
|
||||
# temporary fix
|
||||
if arpabet[-1] == '{Z}' and len(arpabet) > 1:
|
||||
arpabet[-2] = arpabet[-2][:-1] + ' ' + arpabet[-1][1:]
|
||||
del arpabet[-1]
|
||||
|
||||
arpabet = ' '.join(arpabet)
|
||||
elif len(arpabet) == 1:
|
||||
arpabet = "{" + arpabet[0] + "}"
|
||||
else:
|
||||
arpabet = acronym
|
||||
|
||||
return arpabet
|
||||
|
||||
|
||||
def normalize_acronyms(text):
|
||||
text = re.sub(_acronym_re, _expand_acronyms, text)
|
||||
return text
|
|
@ -1,90 +1,92 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
""" adapted from https://github.com/keithito/tacotron """
|
||||
|
||||
'''
|
||||
Cleaners are transformations that run over the input text at both training and eval time.
|
||||
|
||||
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
||||
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
||||
1. "english_cleaners" for English text
|
||||
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||
the symbols in symbols.py to match your data).
|
||||
1. "english_cleaners" for English text
|
||||
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||
the symbols in symbols.py to match your data).
|
||||
'''
|
||||
|
||||
import re
|
||||
from unidecode import unidecode
|
||||
from .numbers import normalize_numbers
|
||||
from .numerical import normalize_numbers
|
||||
from .acronyms import normalize_acronyms
|
||||
from .datestime import normalize_datestime
|
||||
from .letters_and_numbers import normalize_letters_and_numbers
|
||||
from .abbreviations import normalize_abbreviations
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
]]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
return normalize_abbreviations(text)
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def expand_acronyms(text):
|
||||
return normalize_acronyms(text)
|
||||
|
||||
|
||||
def expand_datestime(text):
|
||||
return normalize_datestime(text)
|
||||
|
||||
|
||||
def expand_letters_and_numbers(text):
|
||||
return normalize_letters_and_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def separate_acronyms(text):
|
||||
text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
|
||||
text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
|
||||
return text
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
'''Basic pipeline that collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def english_cleaners_post_chars(word):
|
||||
return word
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
'''Pipeline for English text, including number and abbreviation expansion.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
'''Pipeline for English text, with number and abbreviation expansion.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
|
|
@ -18,7 +18,18 @@ _valid_symbol_set = set(valid_symbols)
|
|||
|
||||
class CMUDict:
|
||||
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
|
||||
def __init__(self, file_or_path, keep_ambiguous=True):
|
||||
def __init__(self, file_or_path=None, heteronyms_path=None, keep_ambiguous=True):
|
||||
if file_or_path is None:
|
||||
self._entries = {}
|
||||
else:
|
||||
self.initialize(file_or_path, keep_ambiguous)
|
||||
|
||||
if heteronyms_path is None:
|
||||
self.heteronyms = []
|
||||
else:
|
||||
self.heteronyms = set(lines_to_list(heteronyms_path))
|
||||
|
||||
def initialize(self, file_or_path, keep_ambiguous=True):
|
||||
if isinstance(file_or_path, str):
|
||||
with open(file_or_path, encoding='latin-1') as f:
|
||||
entries = _parse_cmudict(f)
|
||||
|
@ -28,17 +39,18 @@ class CMUDict:
|
|||
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
||||
self._entries = entries
|
||||
|
||||
|
||||
def __len__(self):
|
||||
if len(self._entries) == 0:
|
||||
raise ValueError("CMUDict not initialized")
|
||||
return len(self._entries)
|
||||
|
||||
|
||||
def lookup(self, word):
|
||||
'''Returns list of ARPAbet pronunciations of the given word.'''
|
||||
if len(self._entries) == 0:
|
||||
raise ValueError("CMUDict not initialized")
|
||||
return self._entries.get(word.upper())
|
||||
|
||||
|
||||
|
||||
_alt_re = re.compile(r'\([0-9]+\)')
|
||||
|
||||
|
||||
|
|
22
PyTorch/SpeechSynthesis/FastPitch/common/text/datestime.py
Normal file
22
PyTorch/SpeechSynthesis/FastPitch/common/text/datestime.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import re
|
||||
_ampm_re = re.compile(
|
||||
r'([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)')
|
||||
|
||||
|
||||
def _expand_ampm(m):
|
||||
matches = list(m.groups(0))
|
||||
txt = matches[0]
|
||||
txt = txt if int(matches[1]) == 0 else txt + ' ' + matches[1]
|
||||
|
||||
if matches[2][0].lower() == 'a':
|
||||
txt += ' a.m.'
|
||||
elif matches[2][0].lower() == 'p':
|
||||
txt += ' p.m.'
|
||||
|
||||
return txt
|
||||
|
||||
|
||||
def normalize_datestime(text):
|
||||
text = re.sub(_ampm_re, _expand_ampm, text)
|
||||
#text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
|
||||
return text
|
|
@ -0,0 +1,90 @@
|
|||
import re
|
||||
_letters_and_numbers_re = re.compile(
|
||||
r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE)
|
||||
|
||||
_hardware_re = re.compile(
|
||||
'([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)', re.IGNORECASE)
|
||||
_hardware_key = {'tb': 'terabyte',
|
||||
'gb': 'gigabyte',
|
||||
'mb': 'megabyte',
|
||||
'kb': 'kilobyte',
|
||||
'ghz': 'gigahertz',
|
||||
'mhz': 'megahertz',
|
||||
'khz': 'kilohertz',
|
||||
'hz': 'hertz',
|
||||
'mm': 'millimeter',
|
||||
'cm': 'centimeter',
|
||||
'km': 'kilometer'}
|
||||
|
||||
_dimension_re = re.compile(
|
||||
r'\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b')
|
||||
_dimension_key = {'m': 'meter',
|
||||
'in': 'inch',
|
||||
'inch': 'inch'}
|
||||
|
||||
|
||||
|
||||
|
||||
def _expand_letters_and_numbers(m):
|
||||
text = re.split(r'(\d+)', m.group(0))
|
||||
|
||||
# remove trailing space
|
||||
if text[-1] == '':
|
||||
text = text[:-1]
|
||||
elif text[0] == '':
|
||||
text = text[1:]
|
||||
|
||||
# if not like 1920s, or AK47's , 20th, 1st, 2nd, 3rd, etc...
|
||||
if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
|
||||
text[-2] = text[-2] + text[-1]
|
||||
text = text[:-1]
|
||||
|
||||
# for combining digits 2 by 2
|
||||
new_text = []
|
||||
for i in range(len(text)):
|
||||
string = text[i]
|
||||
if string.isdigit() and len(string) < 5:
|
||||
# heuristics
|
||||
if len(string) > 2 and string[-2] == '0':
|
||||
if string[-1] == '0':
|
||||
string = [string]
|
||||
else:
|
||||
string = [string[:-3], string[-2], string[-1]]
|
||||
elif len(string) % 2 == 0:
|
||||
string = [string[i:i+2] for i in range(0, len(string), 2)]
|
||||
elif len(string) > 2:
|
||||
string = [string[0]] + [string[i:i+2] for i in range(1, len(string), 2)]
|
||||
new_text.extend(string)
|
||||
else:
|
||||
new_text.append(string)
|
||||
|
||||
text = new_text
|
||||
text = " ".join(text)
|
||||
return text
|
||||
|
||||
|
||||
def _expand_hardware(m):
|
||||
quantity, measure = m.groups(0)
|
||||
measure = _hardware_key[measure.lower()]
|
||||
if measure[-1] != 'z' and float(quantity.replace(',', '')) > 1:
|
||||
return "{} {}s".format(quantity, measure)
|
||||
return "{} {}".format(quantity, measure)
|
||||
|
||||
|
||||
def _expand_dimension(m):
|
||||
text = "".join([x for x in m.groups(0) if x != 0])
|
||||
text = text.replace(' x ', ' by ')
|
||||
text = text.replace('x', ' by ')
|
||||
if text.endswith(tuple(_dimension_key.keys())):
|
||||
if text[-2].isdigit():
|
||||
text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
|
||||
elif text[-3].isdigit():
|
||||
text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
|
||||
return text
|
||||
|
||||
|
||||
def normalize_letters_and_numbers(text):
|
||||
text = re.sub(_hardware_re, _expand_hardware, text)
|
||||
text = re.sub(_dimension_re, _expand_dimension, text)
|
||||
text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
|
||||
return text
|
|
@ -1,71 +0,0 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
import inflect
|
||||
import re
|
||||
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||
_number_re = re.compile(r'[0-9]+')
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace('.', ' point ')
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
153
PyTorch/SpeechSynthesis/FastPitch/common/text/numerical.py
Normal file
153
PyTorch/SpeechSynthesis/FastPitch/common/text/numerical.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
""" adapted from https://github.com/keithito/tacotron """
|
||||
|
||||
import inflect
|
||||
import re
|
||||
_magnitudes = ['trillion', 'billion', 'million', 'thousand', 'hundred', 'm', 'b', 't']
|
||||
_magnitudes_key = {'m': 'million', 'b': 'billion', 't': 'trillion'}
|
||||
_measurements = '(f|c|k|d|m)'
|
||||
_measurements_key = {'f': 'fahrenheit',
|
||||
'c': 'celsius',
|
||||
'k': 'thousand',
|
||||
'm': 'meters'}
|
||||
_currency_key = {'$': 'dollar', '£': 'pound', '€': 'euro', '₩': 'won'}
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||
_currency_re = re.compile(r'([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?'.format("|".join(_magnitudes)), re.IGNORECASE)
|
||||
_measurement_re = re.compile(r'([0-9\.\,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
|
||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||
# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
|
||||
_roman_re = re.compile(r'\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b') # avoid I
|
||||
_multiply_re = re.compile(r'(\b[0-9]+)(x)([0-9]+)')
|
||||
_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace('.', ' point ')
|
||||
|
||||
|
||||
def _expand_currency(m):
|
||||
currency = _currency_key[m.group(1)]
|
||||
quantity = m.group(2)
|
||||
magnitude = m.group(3)
|
||||
|
||||
# remove commas from quantity to be able to convert to numerical
|
||||
quantity = quantity.replace(',', '')
|
||||
|
||||
# check for million, billion, etc...
|
||||
if magnitude is not None and magnitude.lower() in _magnitudes:
|
||||
if len(magnitude) == 1:
|
||||
magnitude = _magnitudes_key[magnitude.lower()]
|
||||
return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency+'s')
|
||||
|
||||
parts = quantity.split('.')
|
||||
if len(parts) > 2:
|
||||
return quantity + " " + currency + "s" # Unexpected format
|
||||
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = currency if dollars == 1 else currency+'s'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return "{} {}, {} {}".format(
|
||||
_expand_hundreds(dollars), dollar_unit,
|
||||
_inflect.number_to_words(cents), cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = currency if dollars == 1 else currency+'s'
|
||||
return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
|
||||
else:
|
||||
return 'zero' + ' ' + currency + 's'
|
||||
|
||||
|
||||
def _expand_hundreds(text):
|
||||
number = float(text)
|
||||
if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
|
||||
return _inflect.number_to_words(int(number / 100)) + " hundred"
|
||||
else:
|
||||
return _inflect.number_to_words(text)
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_measurement(m):
|
||||
_, number, measurement = re.split('(\d+(?:\.\d+)?)', m.group(0))
|
||||
number = _inflect.number_to_words(number)
|
||||
measurement = "".join(measurement.split())
|
||||
measurement = _measurements_key[measurement.lower()]
|
||||
return "{} {}".format(number, measurement)
|
||||
|
||||
|
||||
def _expand_range(m):
|
||||
return ' to '
|
||||
|
||||
|
||||
def _expand_multiply(m):
|
||||
left = m.group(1)
|
||||
right = m.group(3)
|
||||
return "{} by {}".format(left, right)
|
||||
|
||||
|
||||
def _expand_roman(m):
|
||||
# from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
|
||||
roman_numerals = {'I':1, 'V':5, 'X':10, 'L':50, 'C':100, 'D':500, 'M':1000}
|
||||
result = 0
|
||||
num = m.group(0)
|
||||
for i, c in enumerate(num):
|
||||
if (i+1) == len(num) or roman_numerals[c] >= roman_numerals[num[i+1]]:
|
||||
result += roman_numerals[c]
|
||||
else:
|
||||
result -= roman_numerals[c]
|
||||
return str(result)
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
_, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
|
||||
number = int(number)
|
||||
if number > 1000 < 10000 and (number % 100 == 0) and (number % 1000 != 0):
|
||||
text = _inflect.number_to_words(number // 100) + " hundred"
|
||||
elif number > 1000 and number < 3000:
|
||||
if number == 2000:
|
||||
text = 'two thousand'
|
||||
elif number > 2000 and number < 2010:
|
||||
text = 'two thousand ' + _inflect.number_to_words(number % 100)
|
||||
elif number % 100 == 0:
|
||||
text = _inflect.number_to_words(number // 100) + ' hundred'
|
||||
else:
|
||||
number = _inflect.number_to_words(number, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
number = re.sub(r'-', ' ', number)
|
||||
text = number
|
||||
else:
|
||||
number = _inflect.number_to_words(number, andword='and')
|
||||
number = re.sub(r'-', ' ', number)
|
||||
number = re.sub(r',', '', number)
|
||||
text = number
|
||||
|
||||
if suffix in ("'s", "s"):
|
||||
if text[-1] == 'y':
|
||||
text = text[:-1] + 'ies'
|
||||
else:
|
||||
text = text + suffix
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_currency_re, _expand_currency, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
# text = re.sub(_range_re, _expand_range, text)
|
||||
# text = re.sub(_measurement_re, _expand_measurement, text)
|
||||
text = re.sub(_roman_re, _expand_roman, text)
|
||||
text = re.sub(_multiply_re, _expand_multiply, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
|
@ -4,16 +4,41 @@
|
|||
Defines the set of symbols used in text input to the model.
|
||||
|
||||
The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
|
||||
from common.text import cmudict
|
||||
from .cmudict import valid_symbols
|
||||
|
||||
_pad = '_'
|
||||
_punctuation = '!\'(),.:;? '
|
||||
_special = '-'
|
||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
||||
_arpabet = ['@' + s for s in cmudict.valid_symbols]
|
||||
_arpabet = ['@' + s for s in valid_symbols]
|
||||
|
||||
# Export all symbols:
|
||||
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
|
||||
pad_idx = 0
|
||||
|
||||
def get_symbols(symbol_set='english_basic'):
|
||||
if symbol_set == 'english_basic':
|
||||
_pad = '_'
|
||||
_punctuation = '!\'(),.:;? '
|
||||
_special = '-'
|
||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
||||
symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
|
||||
elif symbol_set == 'english_basic_lowercase':
|
||||
_pad = '_'
|
||||
_punctuation = '!\'"(),.:;? '
|
||||
_special = '-'
|
||||
_letters = 'abcdefghijklmnopqrstuvwxyz'
|
||||
symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
|
||||
elif symbol_set == 'english_expanded':
|
||||
_punctuation = '!\'",.:;? '
|
||||
_math = '#%&*+-/[]()'
|
||||
_special = '_@©°½—₩€$'
|
||||
_accented = 'áçéêëñöøćž'
|
||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
||||
symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
|
||||
else:
|
||||
raise Exception("{} symbol set does not exist".format(symbol_set))
|
||||
|
||||
return symbols
|
||||
|
||||
|
||||
def get_pad_idx(symbol_set='english_basic'):
|
||||
if symbol_set in {'english_basic', 'english_basic_lowercase'}:
|
||||
return 0
|
||||
else:
|
||||
raise Exception("{} symbol set not used yet".format(symbol_set))
|
||||
|
|
175
PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py
Normal file
175
PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
""" adapted from https://github.com/keithito/tacotron """
|
||||
import re
|
||||
import numpy as np
|
||||
from . import cleaners
|
||||
from .symbols import get_symbols
|
||||
from .cmudict import CMUDict
|
||||
from .numerical import _currency_re, _expand_currency
|
||||
|
||||
|
||||
#########
|
||||
# REGEX #
|
||||
#########
|
||||
|
||||
# Regular expression matching text enclosed in curly braces for encoding
|
||||
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||
|
||||
# Regular expression matching words and not words
|
||||
_words_re = re.compile(r"([a-zA-ZÀ-ž]+['][a-zA-ZÀ-ž]{1,2}|[a-zA-ZÀ-ž]+)|([{][^}]+[}]|[^a-zA-ZÀ-ž{}]+)")
|
||||
|
||||
# Regular expression separating words enclosed in curly braces for cleaning
|
||||
_arpa_re = re.compile(r'{[^}]+}|\S+')
|
||||
|
||||
|
||||
def lines_to_list(filename):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
lines = [l.rstrip() for l in lines]
|
||||
return lines
|
||||
|
||||
|
||||
class TextProcessing(object):
|
||||
def __init__(self, symbol_set, cleaner_names, p_arpabet=0.0,
|
||||
handle_arpabet='word', handle_arpabet_ambiguous='ignore',
|
||||
expand_currency=True):
|
||||
self.symbols = get_symbols(symbol_set)
|
||||
self.cleaner_names = cleaner_names
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
|
||||
self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
|
||||
self.expand_currency = expand_currency
|
||||
|
||||
# cmudict
|
||||
self.p_arpabet = p_arpabet
|
||||
self.handle_arpabet = handle_arpabet
|
||||
self.handle_arpabet_ambiguous = handle_arpabet_ambiguous
|
||||
|
||||
|
||||
def text_to_sequence(self, text):
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
if not m:
|
||||
sequence += self.symbols_to_sequence(text)
|
||||
break
|
||||
sequence += self.symbols_to_sequence(m.group(1))
|
||||
sequence += self.arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
return sequence
|
||||
|
||||
def sequence_to_text(self, sequence):
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in self.id_to_symbol:
|
||||
s = self.id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
result += s
|
||||
return result.replace('}{', ' ')
|
||||
|
||||
def clean_text(self, text):
|
||||
for name in self.cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
|
||||
return text
|
||||
|
||||
def symbols_to_sequence(self, symbols):
|
||||
return [self.symbol_to_id[s] for s in symbols if s in self.symbol_to_id]
|
||||
|
||||
def arpabet_to_sequence(self, text):
|
||||
return self.symbols_to_sequence(['@' + s for s in text.split()])
|
||||
|
||||
def get_arpabet(self, word):
|
||||
arpabet_suffix = ''
|
||||
|
||||
if word.lower() in cmudict.heteronyms:
|
||||
return word
|
||||
|
||||
if len(word) > 2 and word.endswith("'s"):
|
||||
arpabet = cmudict.lookup(word)
|
||||
if arpabet is None:
|
||||
arpabet = self.get_arpabet(word[:-2])
|
||||
arpabet_suffix = ' Z'
|
||||
elif len(word) > 1 and word.endswith("s"):
|
||||
arpabet = cmudict.lookup(word)
|
||||
if arpabet is None:
|
||||
arpabet = self.get_arpabet(word[:-1])
|
||||
arpabet_suffix = ' Z'
|
||||
else:
|
||||
arpabet = cmudict.lookup(word)
|
||||
|
||||
if arpabet is None:
|
||||
return word
|
||||
elif arpabet[0] == '{':
|
||||
arpabet = [arpabet[1:-1]]
|
||||
|
||||
if len(arpabet) > 1:
|
||||
if self.handle_arpabet_ambiguous == 'first':
|
||||
arpabet = arpabet[0]
|
||||
elif self.handle_arpabet_ambiguous == 'random':
|
||||
arpabet = np.random.choice(arpabet)
|
||||
elif self.handle_arpabet_ambiguous == 'ignore':
|
||||
return word
|
||||
else:
|
||||
arpabet = arpabet[0]
|
||||
|
||||
arpabet = "{" + arpabet + arpabet_suffix + "}"
|
||||
|
||||
return arpabet
|
||||
|
||||
# def get_characters(self, word):
|
||||
# for name in self.cleaner_names:
|
||||
# cleaner = getattr(cleaners, f'{name}_post_chars')
|
||||
# if not cleaner:
|
||||
# raise Exception('Unknown cleaner: %s' % name)
|
||||
# word = cleaner(word)
|
||||
|
||||
# return word
|
||||
|
||||
def encode_text(self, text, return_all=False):
|
||||
if self.expand_currency:
|
||||
text = re.sub(_currency_re, _expand_currency, text)
|
||||
text_clean = [self.clean_text(split) if split[0] != '{' else split
|
||||
for split in _arpa_re.findall(text)]
|
||||
text_clean = ' '.join(text_clean)
|
||||
text = text_clean
|
||||
|
||||
text_arpabet = ''
|
||||
if self.p_arpabet > 0:
|
||||
if self.handle_arpabet == 'sentence':
|
||||
if np.random.uniform() < self.p_arpabet:
|
||||
words = _words_re.findall(text)
|
||||
text_arpabet = [
|
||||
self.get_arpabet(word[0])
|
||||
if (word[0] != '') else word[1]
|
||||
for word in words]
|
||||
text_arpabet = ''.join(text_arpabet)
|
||||
text = text_arpabet
|
||||
elif self.handle_arpabet == 'word':
|
||||
words = _words_re.findall(text)
|
||||
text_arpabet = [
|
||||
word[1] if word[0] == '' else (
|
||||
self.get_arpabet(word[0])
|
||||
if np.random.uniform() < self.p_arpabet
|
||||
else word[0])
|
||||
for word in words]
|
||||
text_arpabet = ''.join(text_arpabet)
|
||||
text = text_arpabet
|
||||
elif self.handle_arpabet != '':
|
||||
raise Exception("{} handle_arpabet is not supported".format(
|
||||
self.handle_arpabet))
|
||||
|
||||
text_encoded = self.text_to_sequence(text)
|
||||
|
||||
if return_all:
|
||||
return text_encoded, text_clean, text_arpabet
|
||||
|
||||
return text_encoded
|
|
@ -25,7 +25,6 @@
|
|||
#
|
||||
# *****************************************************************************
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
@ -48,14 +47,20 @@ def load_wav_to_torch(full_path):
|
|||
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
|
||||
|
||||
|
||||
def load_filepaths_and_text(dataset_path, filename, split="|"):
|
||||
def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
|
||||
def split_line(root, line):
|
||||
parts = line.strip().split(split)
|
||||
paths, text = parts[:-1], parts[-1]
|
||||
return tuple(os.path.join(root, p) for p in paths) + (text,)
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
filepaths_and_text = [split_line(dataset_path, line) for line in f]
|
||||
return filepaths_and_text
|
||||
if has_speakers:
|
||||
paths, non_paths = parts[:-2], parts[-2:]
|
||||
else:
|
||||
paths, non_paths = parts[:-1], parts[-1:]
|
||||
return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
|
||||
|
||||
fpaths_and_text = []
|
||||
for fname in fnames.split(','):
|
||||
with open(fname, encoding='utf-8') as f:
|
||||
fpaths_and_text += [split_line(dataset_path, line) for line in f]
|
||||
return fpaths_and_text
|
||||
|
||||
|
||||
def stats_filename(dataset_path, filelist_path, feature_name):
|
||||
|
|
|
@ -40,6 +40,7 @@ from torch.utils.data import DataLoader
|
|||
from common import utils
|
||||
from inference import load_and_setup_model
|
||||
from tacotron2.data_function import TextMelLoader, TextMelCollate, batch_to_gpu
|
||||
from common.text.text_processing import TextProcessing
|
||||
|
||||
|
||||
def parse_args(parser):
|
||||
|
@ -59,6 +60,8 @@ def parse_args(parser):
|
|||
parser.add_argument('--text-cleaners', nargs='*',
|
||||
default=['english_cleaners'], type=str,
|
||||
help='Type of text cleaners for input text')
|
||||
parser.add_argument('--symbol-set', type=str, default='english_basic',
|
||||
help='Define symbol set for input text')
|
||||
parser.add_argument('--max-wav-value', default=32768.0, type=float,
|
||||
help='Maximum audiowave value')
|
||||
parser.add_argument('--sampling-rate', default=22050, type=int,
|
||||
|
@ -98,6 +101,7 @@ def parse_args(parser):
|
|||
class FilenamedLoader(TextMelLoader):
|
||||
def __init__(self, filenames, *args, **kwargs):
|
||||
super(FilenamedLoader, self).__init__(*args, **kwargs)
|
||||
self.tp = TextProcessing(args[-1].symbol_set, args[-1].text_cleaners)
|
||||
self.filenames = filenames
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
@ -211,6 +215,8 @@ def main():
|
|||
|
||||
filenames = [Path(l.split('|')[0]).stem
|
||||
for l in open(args.wav_text_filelist, 'r')]
|
||||
# Compatibility with Tacotron2 Data loader
|
||||
args.n_speakers = 1
|
||||
dataset = FilenamedLoader(filenames, args.dataset_path, args.wav_text_filelist,
|
||||
args, load_mel_from_disk=False)
|
||||
# TextMelCollate supports only n_frames_per_step=1
|
||||
|
|
|
@ -27,8 +27,6 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from common.text import symbols
|
||||
|
||||
|
||||
def parse_fastpitch_args(parent, add_help=False):
|
||||
"""
|
||||
|
@ -42,11 +40,12 @@ def parse_fastpitch_args(parent, add_help=False):
|
|||
help='Number of bins in mel-spectrograms')
|
||||
io.add_argument('--max-seq-len', default=2048, type=int,
|
||||
help='')
|
||||
global symbols
|
||||
len_symbols = len(symbols)
|
||||
|
||||
symbols = parser.add_argument_group('symbols parameters')
|
||||
symbols.add_argument('--n-symbols', default=len_symbols, type=int,
|
||||
symbols.add_argument('--n-symbols', default=148, type=int,
|
||||
help='Number of symbols in dictionary')
|
||||
symbols.add_argument('--padding-idx', default=0, type=int,
|
||||
help='Index of padding symbol in dictionary')
|
||||
symbols.add_argument('--symbols-embedding-dim', default=384, type=int,
|
||||
help='Input embedding dimension')
|
||||
|
||||
|
@ -102,11 +101,18 @@ def parse_fastpitch_args(parent, add_help=False):
|
|||
|
||||
pitch_pred = parser.add_argument_group('pitch predictor parameters')
|
||||
pitch_pred.add_argument('--pitch-predictor-kernel-size', default=3, type=int,
|
||||
help='Pitch predictor conv-1D kernel size')
|
||||
help='Pitch predictor conv-1D kernel size')
|
||||
pitch_pred.add_argument('--pitch-predictor-filter-size', default=256, type=int,
|
||||
help='Pitch predictor conv-1D filter size')
|
||||
help='Pitch predictor conv-1D filter size')
|
||||
pitch_pred.add_argument('--p-pitch-predictor-dropout', default=0.1, type=float,
|
||||
help='Pitch probability for pitch predictor')
|
||||
help='Pitch probability for pitch predictor')
|
||||
pitch_pred.add_argument('--pitch-predictor-n-layers', default=2, type=int,
|
||||
help='Number of conv-1D layers')
|
||||
help='Number of conv-1D layers')
|
||||
|
||||
cond = parser.add_argument_group('conditioning parameters')
|
||||
cond.add_argument('--pitch-embedding-kernel-size', default=3, type=int,
|
||||
help='Pitch embedding conv-1D kernel size')
|
||||
cond.add_argument('--speaker-emb-weight', type=float, default=1.0,
|
||||
help='Scale speaker embedding')
|
||||
|
||||
return parser
|
||||
|
|
|
@ -31,6 +31,7 @@ import torch
|
|||
|
||||
from common.utils import to_gpu
|
||||
from tacotron2.data_function import TextMelLoader
|
||||
from common.text.text_processing import TextProcessing
|
||||
|
||||
|
||||
class TextMelAliLoader(TextMelLoader):
|
||||
|
@ -38,18 +39,27 @@ class TextMelAliLoader(TextMelLoader):
|
|||
"""
|
||||
def __init__(self, *args):
|
||||
super(TextMelAliLoader, self).__init__(*args)
|
||||
if len(self.audiopaths_and_text[0]) != 4:
|
||||
raise ValueError('Expected four columns in audiopaths file')
|
||||
self.tp = TextProcessing(args[-1].symbol_set, args[-1].text_cleaners)
|
||||
self.n_speakers = args[-1].n_speakers
|
||||
if len(self.audiopaths_and_text[0]) != 4 + (args[-1].n_speakers > 1):
|
||||
raise ValueError('Expected four columns in audiopaths file for single speaker model. \n'
|
||||
'For multispeaker model, the filelist format is '
|
||||
'<mel>|<dur>|<pitch>|<text>|<speaker_id>')
|
||||
|
||||
def __getitem__(self, index):
|
||||
# separate filename and text
|
||||
audiopath, durpath, pitchpath, text = self.audiopaths_and_text[index]
|
||||
if self.n_speakers > 1:
|
||||
audiopath, durpath, pitchpath, text, speaker = self.audiopaths_and_text[index]
|
||||
speaker = int(speaker)
|
||||
else:
|
||||
audiopath, durpath, pitchpath, text = self.audiopaths_and_text[index]
|
||||
speaker = None
|
||||
len_text = len(text)
|
||||
text = self.get_text(text)
|
||||
mel = self.get_mel(audiopath)
|
||||
dur = torch.load(durpath)
|
||||
pitch = torch.load(pitchpath)
|
||||
return (text, mel, len_text, dur, pitch)
|
||||
return (text, mel, len_text, dur, pitch, speaker)
|
||||
|
||||
|
||||
class TextMelAliCollate():
|
||||
|
@ -107,16 +117,24 @@ class TextMelAliCollate():
|
|||
pitch = batch[ids_sorted_decreasing[i]][4]
|
||||
pitch_padded[i, :pitch.shape[0]] = pitch
|
||||
|
||||
if batch[0][5] is not None:
|
||||
speaker = torch.zeros_like(input_lengths)
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
speaker[i] = batch[ids_sorted_decreasing[i]][5]
|
||||
else:
|
||||
speaker = None
|
||||
|
||||
# count number of items - characters in text
|
||||
len_x = [x[2] for x in batch]
|
||||
len_x = torch.Tensor(len_x)
|
||||
return (text_padded, input_lengths, mel_padded,
|
||||
output_lengths, len_x, dur_padded, dur_lens, pitch_padded)
|
||||
|
||||
return (text_padded, input_lengths, mel_padded, output_lengths,
|
||||
len_x, dur_padded, dur_lens, pitch_padded, speaker)
|
||||
|
||||
|
||||
def batch_to_gpu(batch):
|
||||
text_padded, input_lengths, mel_padded, \
|
||||
output_lengths, len_x, dur_padded, dur_lens, pitch_padded = batch
|
||||
text_padded, input_lengths, mel_padded, output_lengths, \
|
||||
len_x, dur_padded, dur_lens, pitch_padded, speaker = batch
|
||||
text_padded = to_gpu(text_padded).long()
|
||||
input_lengths = to_gpu(input_lengths).long()
|
||||
mel_padded = to_gpu(mel_padded).float()
|
||||
|
@ -124,9 +142,11 @@ def batch_to_gpu(batch):
|
|||
dur_padded = to_gpu(dur_padded).long()
|
||||
dur_lens = to_gpu(dur_lens).long()
|
||||
pitch_padded = to_gpu(pitch_padded).float()
|
||||
if speaker is not None:
|
||||
speaker = to_gpu(speaker).long()
|
||||
# Alignments act as both inputs and targets - pass shallow copies
|
||||
x = [text_padded, input_lengths, mel_padded, output_lengths,
|
||||
dur_padded, dur_lens, pitch_padded]
|
||||
dur_padded, dur_lens, pitch_padded, speaker]
|
||||
y = [mel_padded, dur_padded, dur_lens, pitch_padded]
|
||||
len_x = torch.sum(output_lengths)
|
||||
return (x, y, len_x)
|
||||
|
|
|
@ -73,7 +73,7 @@ class FastPitchLoss(nn.Module):
|
|||
'mel_loss': mel_loss.clone().detach(),
|
||||
'duration_predictor_loss': dur_pred_loss.clone().detach(),
|
||||
'pitch_loss': pitch_loss.clone().detach(),
|
||||
'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
|
||||
'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
|
||||
/ dur_mask.sum()).detach(),
|
||||
}
|
||||
assert meta_agg in ('sum', 'mean')
|
||||
|
|
|
@ -1,210 +1,236 @@
|
|||
# *****************************************************************************
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the
|
||||
# names of its contributors may be used to endorse or promote products
|
||||
# derived from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
# *****************************************************************************
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from common.layers import ConvReLUNorm
|
||||
from common.utils import mask_from_lens
|
||||
from fastpitch.transformer import FFTransformer
|
||||
|
||||
|
||||
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
|
||||
"""If target=None, then predicted durations are applied"""
|
||||
reps = torch.round(durations.float() / pace).long()
|
||||
dec_lens = reps.sum(dim=1)
|
||||
|
||||
enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
||||
for o, r in zip(enc_out, reps)],
|
||||
batch_first=True)
|
||||
if mel_max_len:
|
||||
enc_rep = enc_rep[:, :mel_max_len]
|
||||
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
||||
return enc_rep, dec_lens
|
||||
|
||||
|
||||
class TemporalPredictor(nn.Module):
|
||||
"""Predicts a single float per each temporal location"""
|
||||
|
||||
def __init__(self, input_size, filter_size, kernel_size, dropout,
|
||||
n_layers=2):
|
||||
super(TemporalPredictor, self).__init__()
|
||||
|
||||
self.layers = nn.Sequential(*[
|
||||
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
||||
kernel_size=kernel_size, dropout=dropout)
|
||||
for i in range(n_layers)]
|
||||
)
|
||||
self.fc = nn.Linear(filter_size, 1, bias=True)
|
||||
|
||||
def forward(self, enc_out, enc_out_mask):
|
||||
out = enc_out * enc_out_mask
|
||||
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
||||
out = self.fc(out) * enc_out_mask
|
||||
return out.squeeze(-1)
|
||||
|
||||
|
||||
class FastPitch(nn.Module):
|
||||
def __init__(self, n_mel_channels, max_seq_len, n_symbols,
|
||||
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
||||
in_fft_d_head,
|
||||
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
||||
in_fft_output_size,
|
||||
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
||||
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
||||
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
||||
out_fft_output_size,
|
||||
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
||||
dur_predictor_kernel_size, dur_predictor_filter_size,
|
||||
p_dur_predictor_dropout, dur_predictor_n_layers,
|
||||
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
||||
p_pitch_predictor_dropout, pitch_predictor_n_layers):
|
||||
super(FastPitch, self).__init__()
|
||||
del max_seq_len # unused
|
||||
del n_symbols
|
||||
|
||||
self.encoder = FFTransformer(
|
||||
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=in_fft_d_head,
|
||||
d_inner=in_fft_conv1d_filter_size,
|
||||
kernel_size=in_fft_conv1d_kernel_size,
|
||||
dropout=p_in_fft_dropout,
|
||||
dropatt=p_in_fft_dropatt,
|
||||
dropemb=p_in_fft_dropemb,
|
||||
d_embed=symbols_embedding_dim,
|
||||
embed_input=True)
|
||||
|
||||
self.duration_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=dur_predictor_filter_size,
|
||||
kernel_size=dur_predictor_kernel_size,
|
||||
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
||||
)
|
||||
|
||||
self.decoder = FFTransformer(
|
||||
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=out_fft_d_head,
|
||||
d_inner=out_fft_conv1d_filter_size,
|
||||
kernel_size=out_fft_conv1d_kernel_size,
|
||||
dropout=p_out_fft_dropout,
|
||||
dropatt=p_out_fft_dropatt,
|
||||
dropemb=p_out_fft_dropemb,
|
||||
d_embed=symbols_embedding_dim,
|
||||
embed_input=False)
|
||||
|
||||
self.pitch_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=pitch_predictor_filter_size,
|
||||
kernel_size=pitch_predictor_kernel_size,
|
||||
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
||||
)
|
||||
self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
# Store values precomputed for training data within the model
|
||||
self.register_buffer('pitch_mean', torch.zeros(1))
|
||||
self.register_buffer('pitch_std', torch.zeros(1))
|
||||
|
||||
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
||||
|
||||
def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
|
||||
pace=1.0, max_duration=75):
|
||||
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
|
||||
mel_max_len = mel_tgt.size(2)
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Predict pitch
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if use_gt_pitch and pitch_tgt is not None:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
||||
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_tgt if use_gt_durations else dur_pred,
|
||||
enc_out, pace, mel_max_len)
|
||||
|
||||
# Output FFT
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
||||
|
||||
def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
|
||||
pitch_transform=None, max_duration=75):
|
||||
del input_lens # unused
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Pitch over chars
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if pitch_transform is not None:
|
||||
if self.pitch_std[0] == 0.0:
|
||||
# XXX LJSpeech-1.1 defaults
|
||||
mean, std = 218.14, 67.24
|
||||
else:
|
||||
mean, std = self.pitch_mean[0], self.pitch_std[0]
|
||||
pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
|
||||
|
||||
if pitch_tgt is None:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
||||
|
||||
enc_out = enc_out + pitch_emb
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_pred if dur_tgt is None else dur_tgt,
|
||||
enc_out, pace, mel_max_len=None)
|
||||
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
||||
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
||||
return mel_out, dec_lens, dur_pred, pitch_pred
|
||||
# *****************************************************************************
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the
|
||||
# names of its contributors may be used to endorse or promote products
|
||||
# derived from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
# *****************************************************************************
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from common.layers import ConvReLUNorm
|
||||
from common.utils import mask_from_lens
|
||||
from fastpitch.transformer import FFTransformer
|
||||
|
||||
|
||||
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
|
||||
"""If target=None, then predicted durations are applied"""
|
||||
reps = torch.round(durations.float() / pace).long()
|
||||
dec_lens = reps.sum(dim=1)
|
||||
|
||||
enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
||||
for o, r in zip(enc_out, reps)],
|
||||
batch_first=True)
|
||||
if mel_max_len:
|
||||
enc_rep = enc_rep[:, :mel_max_len]
|
||||
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
||||
return enc_rep, dec_lens
|
||||
|
||||
|
||||
class TemporalPredictor(nn.Module):
|
||||
"""Predicts a single float per each temporal location"""
|
||||
|
||||
def __init__(self, input_size, filter_size, kernel_size, dropout,
|
||||
n_layers=2):
|
||||
super(TemporalPredictor, self).__init__()
|
||||
|
||||
self.layers = nn.Sequential(*[
|
||||
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
||||
kernel_size=kernel_size, dropout=dropout)
|
||||
for i in range(n_layers)]
|
||||
)
|
||||
self.fc = nn.Linear(filter_size, 1, bias=True)
|
||||
|
||||
def forward(self, enc_out, enc_out_mask):
|
||||
out = enc_out * enc_out_mask
|
||||
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
||||
out = self.fc(out) * enc_out_mask
|
||||
return out.squeeze(-1)
|
||||
|
||||
|
||||
class FastPitch(nn.Module):
|
||||
def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
|
||||
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
||||
in_fft_d_head,
|
||||
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
||||
in_fft_output_size,
|
||||
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
||||
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
||||
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
||||
out_fft_output_size,
|
||||
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
||||
dur_predictor_kernel_size, dur_predictor_filter_size,
|
||||
p_dur_predictor_dropout, dur_predictor_n_layers,
|
||||
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
||||
p_pitch_predictor_dropout, pitch_predictor_n_layers,
|
||||
pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
|
||||
super(FastPitch, self).__init__()
|
||||
del max_seq_len # unused
|
||||
|
||||
self.encoder = FFTransformer(
|
||||
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=in_fft_d_head,
|
||||
d_inner=in_fft_conv1d_filter_size,
|
||||
kernel_size=in_fft_conv1d_kernel_size,
|
||||
dropout=p_in_fft_dropout,
|
||||
dropatt=p_in_fft_dropatt,
|
||||
dropemb=p_in_fft_dropemb,
|
||||
embed_input=True,
|
||||
d_embed=symbols_embedding_dim,
|
||||
n_embed=n_symbols,
|
||||
padding_idx=padding_idx)
|
||||
|
||||
if n_speakers > 1:
|
||||
self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
|
||||
else:
|
||||
self.speaker_emb = None
|
||||
self.speaker_emb_weight = speaker_emb_weight
|
||||
|
||||
self.duration_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=dur_predictor_filter_size,
|
||||
kernel_size=dur_predictor_kernel_size,
|
||||
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
||||
)
|
||||
|
||||
self.decoder = FFTransformer(
|
||||
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=out_fft_d_head,
|
||||
d_inner=out_fft_conv1d_filter_size,
|
||||
kernel_size=out_fft_conv1d_kernel_size,
|
||||
dropout=p_out_fft_dropout,
|
||||
dropatt=p_out_fft_dropatt,
|
||||
dropemb=p_out_fft_dropemb,
|
||||
embed_input=False,
|
||||
d_embed=symbols_embedding_dim
|
||||
)
|
||||
|
||||
self.pitch_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=pitch_predictor_filter_size,
|
||||
kernel_size=pitch_predictor_kernel_size,
|
||||
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
||||
)
|
||||
|
||||
self.pitch_emb = nn.Conv1d(
|
||||
1, symbols_embedding_dim,
|
||||
kernel_size=pitch_embedding_kernel_size,
|
||||
padding=int((pitch_embedding_kernel_size - 1) / 2))
|
||||
|
||||
# Store values precomputed for training data within the model
|
||||
self.register_buffer('pitch_mean', torch.zeros(1))
|
||||
self.register_buffer('pitch_std', torch.zeros(1))
|
||||
|
||||
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
||||
|
||||
def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
|
||||
pace=1.0, max_duration=75):
|
||||
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
|
||||
mel_max_len = mel_tgt.size(2)
|
||||
|
||||
# Calculate speaker embedding
|
||||
if self.speaker_emb is None:
|
||||
spk_emb = 0
|
||||
else:
|
||||
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||
spk_emb.mul_(self.speaker_emb_weight)
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Predict pitch
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if use_gt_pitch and pitch_tgt is not None:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
||||
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_tgt if use_gt_durations else dur_pred,
|
||||
enc_out, pace, mel_max_len)
|
||||
|
||||
# Output FFT
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
||||
|
||||
def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
|
||||
pitch_transform=None, max_duration=75, speaker=0):
|
||||
del input_lens # unused
|
||||
|
||||
if self.speaker_emb is None:
|
||||
spk_emb = 0
|
||||
else:
|
||||
speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
|
||||
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||
spk_emb.mul_(self.speaker_emb_weight)
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Pitch over chars
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if pitch_transform is not None:
|
||||
if self.pitch_std[0] == 0.0:
|
||||
# XXX LJSpeech-1.1 defaults
|
||||
mean, std = 218.14, 67.24
|
||||
else:
|
||||
mean, std = self.pitch_mean[0], self.pitch_std[0]
|
||||
pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
|
||||
|
||||
if pitch_tgt is None:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
||||
|
||||
enc_out = enc_out + pitch_emb
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_pred if dur_tgt is None else dur_tgt,
|
||||
enc_out, pace, mel_max_len=None)
|
||||
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
||||
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
||||
return mel_out, dec_lens, dur_pred, pitch_pred
|
||||
|
|
|
@ -1,218 +1,246 @@
|
|||
# *****************************************************************************
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the
|
||||
# names of its contributors may be used to endorse or promote products
|
||||
# derived from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
# *****************************************************************************
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from common.layers import ConvReLUNorm
|
||||
from fastpitch.transformer_jit import FFTransformer
|
||||
|
||||
|
||||
def regulate_len(durations, enc_out, pace: float = 1.0,
|
||||
mel_max_len: Optional[int] = None):
|
||||
"""If target=None, then predicted durations are applied"""
|
||||
reps = torch.round(durations.float() / pace).long()
|
||||
dec_lens = reps.sum(dim=1)
|
||||
|
||||
max_len = dec_lens.max()
|
||||
bsz, _, hid = enc_out.size()
|
||||
|
||||
reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
|
||||
pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
|
||||
device=enc_out.device)
|
||||
|
||||
enc_rep = torch.cat([enc_out, pad_vec], dim=1)
|
||||
enc_rep = torch.repeat_interleave(
|
||||
enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
|
||||
).view(bsz, -1, hid)
|
||||
|
||||
# enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
||||
# for o, r in zip(enc_out, reps)],
|
||||
# batch_first=True)
|
||||
if mel_max_len is not None:
|
||||
enc_rep = enc_rep[:, :mel_max_len]
|
||||
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
||||
return enc_rep, dec_lens
|
||||
|
||||
|
||||
class TemporalPredictor(nn.Module):
|
||||
"""Predicts a single float per each temporal location"""
|
||||
|
||||
def __init__(self, input_size, filter_size, kernel_size, dropout,
|
||||
n_layers=2):
|
||||
super(TemporalPredictor, self).__init__()
|
||||
|
||||
self.layers = nn.Sequential(*[
|
||||
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
||||
kernel_size=kernel_size, dropout=dropout)
|
||||
for i in range(n_layers)]
|
||||
)
|
||||
self.fc = nn.Linear(filter_size, 1, bias=True)
|
||||
|
||||
def forward(self, enc_out, enc_out_mask):
|
||||
out = enc_out * enc_out_mask
|
||||
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
||||
out = self.fc(out) * enc_out_mask
|
||||
return out.squeeze(-1)
|
||||
|
||||
|
||||
class FastPitch(nn.Module):
|
||||
def __init__(self, n_mel_channels, max_seq_len, n_symbols,
|
||||
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
||||
in_fft_d_head,
|
||||
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
||||
in_fft_output_size,
|
||||
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
||||
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
||||
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
||||
out_fft_output_size,
|
||||
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
||||
dur_predictor_kernel_size, dur_predictor_filter_size,
|
||||
p_dur_predictor_dropout, dur_predictor_n_layers,
|
||||
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
||||
p_pitch_predictor_dropout, pitch_predictor_n_layers):
|
||||
super(FastPitch, self).__init__()
|
||||
del max_seq_len # unused
|
||||
del n_symbols
|
||||
|
||||
self.encoder = FFTransformer(
|
||||
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=in_fft_d_head,
|
||||
d_inner=in_fft_conv1d_filter_size,
|
||||
kernel_size=in_fft_conv1d_kernel_size,
|
||||
dropout=p_in_fft_dropout,
|
||||
dropatt=p_in_fft_dropatt,
|
||||
dropemb=p_in_fft_dropemb,
|
||||
d_embed=symbols_embedding_dim,
|
||||
embed_input=True)
|
||||
|
||||
self.duration_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=dur_predictor_filter_size,
|
||||
kernel_size=dur_predictor_kernel_size,
|
||||
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
||||
)
|
||||
|
||||
self.decoder = FFTransformer(
|
||||
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=out_fft_d_head,
|
||||
d_inner=out_fft_conv1d_filter_size,
|
||||
kernel_size=out_fft_conv1d_kernel_size,
|
||||
dropout=p_out_fft_dropout,
|
||||
dropatt=p_out_fft_dropatt,
|
||||
dropemb=p_out_fft_dropemb,
|
||||
d_embed=symbols_embedding_dim,
|
||||
embed_input=False)
|
||||
|
||||
self.pitch_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=pitch_predictor_filter_size,
|
||||
kernel_size=pitch_predictor_kernel_size,
|
||||
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
||||
)
|
||||
self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
# Store values precomputed for training data within the model
|
||||
self.register_buffer('pitch_mean', torch.zeros(1))
|
||||
self.register_buffer('pitch_std', torch.zeros(1))
|
||||
|
||||
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
|
||||
use_gt_pitch: bool = True, pace: float = 1.0,
|
||||
max_duration: int = 75):
|
||||
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
|
||||
mel_max_len = mel_tgt.size(2)
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Predict pitch
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if use_gt_pitch and pitch_tgt is not None:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
||||
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_tgt if use_gt_durations else dur_pred,
|
||||
enc_out, pace, mel_max_len)
|
||||
|
||||
# Output FFT
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
||||
|
||||
def infer(self, inputs, input_lens, pace: float = 1.0,
|
||||
dur_tgt: Optional[torch.Tensor] = None,
|
||||
pitch_tgt: Optional[torch.Tensor] = None,
|
||||
max_duration: float = 75):
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Pitch over chars
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if pitch_tgt is None:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
||||
|
||||
enc_out = enc_out + pitch_emb
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_pred if dur_tgt is None else dur_tgt,
|
||||
enc_out, pace, mel_max_len=None)
|
||||
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
||||
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
||||
return mel_out, dec_lens, dur_pred, pitch_pred
|
||||
# *****************************************************************************
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the
|
||||
# names of its contributors may be used to endorse or promote products
|
||||
# derived from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
# *****************************************************************************
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from common.layers import ConvReLUNorm
|
||||
from fastpitch.transformer_jit import FFTransformer
|
||||
|
||||
|
||||
def regulate_len(durations, enc_out, pace: float = 1.0,
|
||||
mel_max_len: Optional[int] = None):
|
||||
"""If target=None, then predicted durations are applied"""
|
||||
reps = torch.round(durations.float() / pace).long()
|
||||
dec_lens = reps.sum(dim=1)
|
||||
|
||||
max_len = dec_lens.max()
|
||||
bsz, _, hid = enc_out.size()
|
||||
|
||||
reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
|
||||
pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
|
||||
device=enc_out.device)
|
||||
|
||||
enc_rep = torch.cat([enc_out, pad_vec], dim=1)
|
||||
enc_rep = torch.repeat_interleave(
|
||||
enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
|
||||
).view(bsz, -1, hid)
|
||||
|
||||
# enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
||||
# for o, r in zip(enc_out, reps)],
|
||||
# batch_first=True)
|
||||
if mel_max_len is not None:
|
||||
enc_rep = enc_rep[:, :mel_max_len]
|
||||
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
||||
return enc_rep, dec_lens
|
||||
|
||||
|
||||
class TemporalPredictor(nn.Module):
|
||||
"""Predicts a single float per each temporal location"""
|
||||
|
||||
def __init__(self, input_size, filter_size, kernel_size, dropout,
|
||||
n_layers=2):
|
||||
super(TemporalPredictor, self).__init__()
|
||||
|
||||
self.layers = nn.Sequential(*[
|
||||
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
||||
kernel_size=kernel_size, dropout=dropout)
|
||||
for i in range(n_layers)]
|
||||
)
|
||||
self.fc = nn.Linear(filter_size, 1, bias=True)
|
||||
|
||||
def forward(self, enc_out, enc_out_mask):
|
||||
out = enc_out * enc_out_mask
|
||||
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
||||
out = self.fc(out) * enc_out_mask
|
||||
return out.squeeze(-1)
|
||||
|
||||
|
||||
class FastPitch(nn.Module):
|
||||
def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
|
||||
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
||||
in_fft_d_head,
|
||||
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
||||
in_fft_output_size,
|
||||
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
||||
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
||||
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
||||
out_fft_output_size,
|
||||
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
||||
dur_predictor_kernel_size, dur_predictor_filter_size,
|
||||
p_dur_predictor_dropout, dur_predictor_n_layers,
|
||||
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
||||
p_pitch_predictor_dropout, pitch_predictor_n_layers,
|
||||
pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
|
||||
super(FastPitch, self).__init__()
|
||||
del max_seq_len # unused
|
||||
|
||||
self.encoder = FFTransformer(
|
||||
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=in_fft_d_head,
|
||||
d_inner=in_fft_conv1d_filter_size,
|
||||
kernel_size=in_fft_conv1d_kernel_size,
|
||||
dropout=p_in_fft_dropout,
|
||||
dropatt=p_in_fft_dropatt,
|
||||
dropemb=p_in_fft_dropemb,
|
||||
embed_input=True,
|
||||
d_embed=symbols_embedding_dim,
|
||||
n_embed=n_symbols,
|
||||
padding_idx=padding_idx)
|
||||
|
||||
if n_speakers > 1:
|
||||
self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
|
||||
else:
|
||||
self.speaker_emb = None
|
||||
self.speaker_emb_weight = speaker_emb_weight
|
||||
|
||||
self.duration_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=dur_predictor_filter_size,
|
||||
kernel_size=dur_predictor_kernel_size,
|
||||
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
||||
)
|
||||
|
||||
self.decoder = FFTransformer(
|
||||
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
||||
d_model=symbols_embedding_dim,
|
||||
d_head=out_fft_d_head,
|
||||
d_inner=out_fft_conv1d_filter_size,
|
||||
kernel_size=out_fft_conv1d_kernel_size,
|
||||
dropout=p_out_fft_dropout,
|
||||
dropatt=p_out_fft_dropatt,
|
||||
dropemb=p_out_fft_dropemb,
|
||||
embed_input=False,
|
||||
d_embed=symbols_embedding_dim
|
||||
)
|
||||
|
||||
self.pitch_predictor = TemporalPredictor(
|
||||
in_fft_output_size,
|
||||
filter_size=pitch_predictor_filter_size,
|
||||
kernel_size=pitch_predictor_kernel_size,
|
||||
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
||||
)
|
||||
|
||||
self.pitch_emb = nn.Conv1d(
|
||||
1, symbols_embedding_dim,
|
||||
kernel_size=pitch_embedding_kernel_size,
|
||||
padding=int((pitch_embedding_kernel_size - 1) / 2))
|
||||
|
||||
# Store values precomputed for training data within the model
|
||||
self.register_buffer('pitch_mean', torch.zeros(1))
|
||||
self.register_buffer('pitch_std', torch.zeros(1))
|
||||
|
||||
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
|
||||
use_gt_pitch: bool = True, pace: float = 1.0,
|
||||
max_duration: int = 75):
|
||||
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
|
||||
mel_max_len = mel_tgt.size(2)
|
||||
|
||||
# Calculate speaker embedding
|
||||
if self.speaker_emb is None:
|
||||
spk_emb = 0
|
||||
else:
|
||||
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||
spk_emb.mul_(self.speaker_emb_weight)
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Predict pitch
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if use_gt_pitch and pitch_tgt is not None:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
||||
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_tgt if use_gt_durations else dur_pred,
|
||||
enc_out, pace, mel_max_len)
|
||||
|
||||
# Output FFT
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
||||
|
||||
def infer(self, inputs, input_lens, pace: float = 1.0,
|
||||
dur_tgt: Optional[torch.Tensor] = None,
|
||||
pitch_tgt: Optional[torch.Tensor] = None,
|
||||
max_duration: float = 75,
|
||||
speaker: int = 0):
|
||||
del input_lens # unused
|
||||
|
||||
if self.speaker_emb is None:
|
||||
spk_emb = None
|
||||
else:
|
||||
speaker = torch.ones(inputs.size(0), dtype=torch.long, device=inputs.device).fill_(speaker)
|
||||
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||
spk_emb.mul_(self.speaker_emb_weight)
|
||||
|
||||
# Input FFT
|
||||
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||
|
||||
# Embedded for predictors
|
||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||
|
||||
# Predict durations
|
||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||
|
||||
# Pitch over chars
|
||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||
|
||||
if pitch_tgt is None:
|
||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
||||
else:
|
||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
||||
|
||||
enc_out = enc_out + pitch_emb
|
||||
|
||||
len_regulated, dec_lens = regulate_len(
|
||||
dur_pred if dur_tgt is None else dur_tgt,
|
||||
enc_out, pace, mel_max_len=None)
|
||||
|
||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||
mel_out = self.proj(dec_out)
|
||||
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
||||
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
||||
return mel_out, dec_lens, dur_pred, pitch_pred
|
||||
|
|
|
@ -17,7 +17,6 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from common.utils import mask_from_lens
|
||||
from common.text.symbols import pad_idx, symbols
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
|
@ -248,16 +247,17 @@ class TransformerLayer(nn.Module):
|
|||
|
||||
class FFTransformer(nn.Module):
|
||||
def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
|
||||
dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
|
||||
pre_lnorm=False):
|
||||
dropout, dropatt, dropemb=0.0, embed_input=True,
|
||||
n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
|
||||
super(FFTransformer, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.n_head = n_head
|
||||
self.d_head = d_head
|
||||
self.padding_idx = padding_idx
|
||||
|
||||
if embed_input:
|
||||
self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
|
||||
padding_idx=pad_idx)
|
||||
self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
|
||||
padding_idx=self.padding_idx)
|
||||
else:
|
||||
self.word_emb = None
|
||||
|
||||
|
@ -272,18 +272,18 @@ class FFTransformer(nn.Module):
|
|||
dropatt=dropatt, pre_lnorm=pre_lnorm)
|
||||
)
|
||||
|
||||
def forward(self, dec_inp, seq_lens=None):
|
||||
def forward(self, dec_inp, seq_lens=None, conditioning=0):
|
||||
if self.word_emb is None:
|
||||
inp = dec_inp
|
||||
mask = mask_from_lens(seq_lens).unsqueeze(2)
|
||||
else:
|
||||
inp = self.word_emb(dec_inp)
|
||||
# [bsz x L x 1]
|
||||
mask = (dec_inp != pad_idx).unsqueeze(2)
|
||||
mask = (dec_inp != self.padding_idx).unsqueeze(2)
|
||||
|
||||
pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
|
||||
pos_emb = self.pos_emb(pos_seq) * mask
|
||||
out = self.drop(inp + pos_emb)
|
||||
out = self.drop(inp + pos_emb + conditioning)
|
||||
|
||||
for layer in self.layers:
|
||||
out = layer(out, mask=mask)
|
||||
|
|
|
@ -19,7 +19,6 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from common.utils import mask_from_lens
|
||||
from common.text.symbols import pad_idx, symbols
|
||||
|
||||
|
||||
class NoOp(nn.Module):
|
||||
|
@ -255,20 +254,20 @@ class TransformerLayer(nn.Module):
|
|||
|
||||
|
||||
class FFTransformer(nn.Module):
|
||||
pad_idx = 0 # XXX
|
||||
|
||||
def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
|
||||
dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
|
||||
pre_lnorm=False):
|
||||
dropout, dropatt, dropemb=0.0, embed_input=True,
|
||||
n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
|
||||
super(FFTransformer, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.n_head = n_head
|
||||
self.d_head = d_head
|
||||
self.padding_idx = padding_idx
|
||||
self.n_embed = n_embed
|
||||
|
||||
self.embed_input = embed_input
|
||||
if embed_input:
|
||||
self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
|
||||
padding_idx=FFTransformer.pad_idx)
|
||||
self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
|
||||
padding_idx=self.padding_idx)
|
||||
else:
|
||||
self.word_emb = NoOp()
|
||||
|
||||
|
@ -283,19 +282,23 @@ class FFTransformer(nn.Module):
|
|||
dropatt=dropatt, pre_lnorm=pre_lnorm)
|
||||
)
|
||||
|
||||
def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None):
|
||||
if self.embed_input:
|
||||
inp = self.word_emb(dec_inp)
|
||||
# [bsz x L x 1]
|
||||
# mask = (dec_inp != FFTransformer.pad_idx).unsqueeze(2)
|
||||
mask = (dec_inp != 0).unsqueeze(2)
|
||||
else:
|
||||
def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
|
||||
conditioning: Optional[torch.Tensor] = None):
|
||||
if not self.embed_input:
|
||||
inp = dec_inp
|
||||
assert seq_lens is not None
|
||||
mask = mask_from_lens(seq_lens).unsqueeze(2)
|
||||
else:
|
||||
inp = self.word_emb(dec_inp)
|
||||
# [bsz x L x 1]
|
||||
mask = (dec_inp != self.padding_idx).unsqueeze(2)
|
||||
|
||||
pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
|
||||
pos_emb = self.pos_emb(pos_seq) * mask
|
||||
out = self.drop(inp + pos_emb)
|
||||
if conditioning is not None:
|
||||
out = self.drop(inp + pos_emb + conditioning)
|
||||
else:
|
||||
out = self.drop(inp + pos_emb)
|
||||
|
||||
for layer in self.layers:
|
||||
out = layer(out, mask=mask)
|
||||
|
|
|
@ -45,7 +45,7 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
|
|||
from common import utils
|
||||
from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
|
||||
unique_log_fpath)
|
||||
from common.text import text_to_sequence
|
||||
from common.text.text_processing import TextProcessing
|
||||
from pitch_transform import pitch_transform_custom
|
||||
from waveglow import model as glow
|
||||
from waveglow.denoiser import Denoiser
|
||||
|
@ -92,9 +92,11 @@ def parse_args(parser):
|
|||
help='Use EMA averaged model (if saved in checkpoints)')
|
||||
parser.add_argument('--dataset-path', type=str,
|
||||
help='Path to dataset (for loading extra data fields)')
|
||||
parser.add_argument('--speaker', type=int, default=0,
|
||||
help='Speaker ID for a multi-speaker model')
|
||||
|
||||
transform = parser.add_argument_group('transform')
|
||||
transform.add_argument('--fade-out', type=int, default=5,
|
||||
transform.add_argument('--fade-out', type=int, default=10,
|
||||
help='Number of fadeout frames at the end')
|
||||
transform.add_argument('--pace', type=float, default=1.0,
|
||||
help='Adjust the pace of speech')
|
||||
|
@ -108,6 +110,18 @@ def parse_args(parser):
|
|||
help='Raise/lower the pitch by <hz>')
|
||||
transform.add_argument('--pitch-transform-custom', action='store_true',
|
||||
help='Apply the transform from pitch_transform.py')
|
||||
|
||||
text_processing = parser.add_argument_group('Text processing parameters')
|
||||
text_processing.add_argument('--text-cleaners', nargs='*',
|
||||
default=['english_cleaners'], type=str,
|
||||
help='Type of text cleaners for input text')
|
||||
text_processing.add_argument('--symbol-set', type=str, default='english_basic',
|
||||
help='Define symbol set for input text')
|
||||
|
||||
cond = parser.add_argument_group('conditioning on additional attributes')
|
||||
cond.add_argument('--n-speakers', type=int, default=1,
|
||||
help='Number of speakers in the model.')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -138,7 +152,7 @@ def load_and_setup_model(model_name, parser, checkpoint, amp, device,
|
|||
|
||||
if any(key.startswith('module.') for key in sd):
|
||||
sd = {k.replace('module.', ''): v for k,v in sd.items()}
|
||||
status += ' ' + str(model.load_state_dict(sd, strict=False))
|
||||
status += ' ' + str(model.load_state_dict(sd, strict=True))
|
||||
else:
|
||||
model = checkpoint_data['model']
|
||||
print(f'Loaded {model_name}{status}')
|
||||
|
@ -162,10 +176,13 @@ def load_fields(fpath):
|
|||
return {c:f for c, f in zip(columns, fields)}
|
||||
|
||||
|
||||
def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
|
||||
load_mels=False, load_pitch=False):
|
||||
fields['text'] = [torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
|
||||
for t in fields['text']]
|
||||
def prepare_input_sequence(fields, device, symbol_set, text_cleaners,
|
||||
batch_size=128, dataset=None, load_mels=False,
|
||||
load_pitch=False):
|
||||
tp = TextProcessing(symbol_set, text_cleaners)
|
||||
|
||||
fields['text'] = [torch.LongTensor(tp.encode_text(text))
|
||||
for text in fields['text']]
|
||||
order = np.argsort([-t.size(0) for t in fields['text']])
|
||||
|
||||
fields['text'] = [fields['text'][i] for i in order]
|
||||
|
@ -206,7 +223,6 @@ def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
|
|||
|
||||
|
||||
def build_pitch_transformation(args):
|
||||
|
||||
if args.pitch_transform_custom:
|
||||
def custom_(pitch, pitch_lens, mean, std):
|
||||
return (pitch_transform_custom(pitch * std + mean, pitch_lens)
|
||||
|
@ -262,7 +278,7 @@ def main():
|
|||
StdOutBackend(Verbosity.VERBOSE,
|
||||
metric_format=stdout_metric_format)])
|
||||
init_inference_metadata()
|
||||
[DLLogger.log("PARAMETER", {k:v}) for k,v in vars(args).items()]
|
||||
[DLLogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]
|
||||
|
||||
device = torch.device('cuda' if args.cuda else 'cpu')
|
||||
|
||||
|
@ -293,8 +309,8 @@ def main():
|
|||
|
||||
fields = load_fields(args.input)
|
||||
batches = prepare_input_sequence(
|
||||
fields, device, args.batch_size, args.dataset_path,
|
||||
load_mels=(generator is None))
|
||||
fields, device, args.symbol_set, args.text_cleaners, args.batch_size,
|
||||
args.dataset_path, load_mels=(generator is None))
|
||||
|
||||
if args.include_warmup:
|
||||
# Use real data rather than synthetic - FastPitch predicts len
|
||||
|
@ -311,11 +327,13 @@ def main():
|
|||
waveglow_measures = MeasureTime()
|
||||
|
||||
gen_kw = {'pace': args.pace,
|
||||
'speaker': args.speaker,
|
||||
'pitch_tgt': None,
|
||||
'pitch_transform': build_pitch_transformation(args)}
|
||||
|
||||
if args.torchscript:
|
||||
gen_kw.pop('pitch_transform')
|
||||
print('NOTE: Pitch transforms are disabled with TorchScript')
|
||||
|
||||
all_utterances = 0
|
||||
all_samples = 0
|
||||
|
@ -323,11 +341,10 @@ def main():
|
|||
all_frames = 0
|
||||
|
||||
reps = args.repeats
|
||||
log_enabled = True # reps == 1
|
||||
log_enabled = reps == 1
|
||||
log = lambda s, d: DLLogger.log(step=s, data=d) if log_enabled else None
|
||||
|
||||
# for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
|
||||
for rep in range(reps):
|
||||
for rep in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
|
||||
for b in batches:
|
||||
if generator is None:
|
||||
log(rep, {'Synthesizing from ground truth mels'})
|
||||
|
@ -348,7 +365,7 @@ def main():
|
|||
audios = waveglow(mel, sigma=args.sigma_infer)
|
||||
audios = denoiser(audios.float(),
|
||||
strength=args.denoising_strength
|
||||
).squeeze(1)
|
||||
).squeeze(1)
|
||||
|
||||
all_utterances += len(audios)
|
||||
all_samples += sum(audio.size(0) for audio in audios)
|
||||
|
@ -367,7 +384,7 @@ def main():
|
|||
fade_w = torch.linspace(1.0, 0.0, fade_len)
|
||||
audio[-fade_len:] *= fade_w.to(audio.device)
|
||||
|
||||
audio = audio/torch.max(torch.abs(audio))
|
||||
audio = audio / torch.max(torch.abs(audio))
|
||||
fname = b['output'][i] if 'output' in b else f'audio_{i}.wav'
|
||||
audio_path = Path(args.output, fname)
|
||||
write(audio_path, args.sampling_rate, audio.cpu().numpy())
|
||||
|
|
|
@ -37,6 +37,7 @@ from fastpitch.model import FastPitch as _FastPitch
|
|||
from fastpitch.model_jit import FastPitch as _FastPitchJIT
|
||||
from tacotron2.model import Tacotron2
|
||||
from waveglow.model import WaveGlow
|
||||
from common.text.symbols import get_symbols, get_pad_idx
|
||||
|
||||
|
||||
def parse_model_args(model_name, parser, add_help=False):
|
||||
|
@ -94,25 +95,27 @@ def get_model(model_name, model_config, device,
|
|||
model = WaveGlow(**model_config)
|
||||
|
||||
elif model_name == 'FastPitch':
|
||||
|
||||
if forward_is_infer:
|
||||
|
||||
if jitable:
|
||||
class FastPitch__forward_is_infer(_FastPitchJIT):
|
||||
def forward(self, inputs, input_lengths, pace: float = 1.0,
|
||||
dur_tgt: Optional[torch.Tensor] = None,
|
||||
pitch_tgt: Optional[torch.Tensor] = None):
|
||||
pitch_tgt: Optional[torch.Tensor] = None,
|
||||
speaker: int = 0):
|
||||
return self.infer(inputs, input_lengths, pace=pace,
|
||||
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt)
|
||||
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
|
||||
speaker=speaker)
|
||||
else:
|
||||
class FastPitch__forward_is_infer(_FastPitch):
|
||||
def forward(self, inputs, input_lengths, pace: float = 1.0,
|
||||
dur_tgt: Optional[torch.Tensor] = None,
|
||||
pitch_tgt: Optional[torch.Tensor] = None,
|
||||
pitch_transform=None):
|
||||
pitch_transform=None,
|
||||
speaker: Optional[int] = None):
|
||||
return self.infer(inputs, input_lengths, pace=pace,
|
||||
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
|
||||
pitch_transform=pitch_transform)
|
||||
pitch_transform=pitch_transform,
|
||||
speaker=speaker)
|
||||
|
||||
model = FastPitch__forward_is_infer(**model_config)
|
||||
else:
|
||||
|
@ -136,7 +139,7 @@ def get_model_config(model_name, args):
|
|||
# audio
|
||||
n_mel_channels=args.n_mel_channels,
|
||||
# symbols
|
||||
n_symbols=args.n_symbols,
|
||||
n_symbols=len(get_symbols(args.symbol_set)),
|
||||
symbols_embedding_dim=args.symbols_embedding_dim,
|
||||
# encoder
|
||||
encoder_kernel_size=args.encoder_kernel_size,
|
||||
|
@ -183,7 +186,8 @@ def get_model_config(model_name, args):
|
|||
n_mel_channels=args.n_mel_channels,
|
||||
max_seq_len=args.max_seq_len,
|
||||
# symbols
|
||||
n_symbols=args.n_symbols,
|
||||
n_symbols=len(get_symbols(args.symbol_set)),
|
||||
padding_idx=get_pad_idx(args.symbol_set),
|
||||
symbols_embedding_dim=args.symbols_embedding_dim,
|
||||
# input FFT
|
||||
in_fft_n_layers=args.in_fft_n_layers,
|
||||
|
@ -215,6 +219,11 @@ def get_model_config(model_name, args):
|
|||
pitch_predictor_filter_size=args.pitch_predictor_filter_size,
|
||||
p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
|
||||
pitch_predictor_n_layers=args.pitch_predictor_n_layers,
|
||||
# pitch conditioning
|
||||
pitch_embedding_kernel_size=args.pitch_embedding_kernel_size,
|
||||
# speakers parameters
|
||||
n_speakers=args.n_speakers,
|
||||
speaker_emb_weight=args.speaker_emb_weight
|
||||
)
|
||||
return model_config
|
||||
|
||||
|
|
|
@ -6,16 +6,13 @@ DATA_DIR="LJSpeech-1.1"
|
|||
LJS_ARCH="LJSpeech-1.1.tar.bz2"
|
||||
LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"
|
||||
|
||||
if [ ! -f ${LJS_ARCH} ]; then
|
||||
if [ ! -d ${DATA_DIR} ]; then
|
||||
echo "Downloading ${LJS_ARCH} ..."
|
||||
wget -q ${LJS_URL}
|
||||
fi
|
||||
|
||||
if [ ! -d ${DATA_DIR} ]; then
|
||||
echo "Extracting ${LJS_ARCH} ..."
|
||||
tar jxvf ${LJS_ARCH}
|
||||
rm -f ${LJS_ARCH}
|
||||
fi
|
||||
|
||||
bash scripts/download_tacotron2.sh
|
||||
bash scripts/download_waveglow.sh
|
||||
bash ./scripts/download_tacotron2.sh
|
||||
bash ./scripts/download_waveglow.sh
|
||||
|
|
|
@ -2,19 +2,20 @@
|
|||
|
||||
set -e
|
||||
|
||||
MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
|
||||
FASTP_ZIP="nvidia_fastpitch_200518.zip"
|
||||
FASTP_CH="nvidia_fastpitch_200518.pt"
|
||||
FASTP_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"
|
||||
: ${MODEL_DIR:="pretrained_models/fastpitch"}
|
||||
MODEL_ZIP="nvidia_fastpitch_200518.zip"
|
||||
MODEL_CH="nvidia_fastpitch_200518.pt"
|
||||
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"
|
||||
|
||||
mkdir -p "$MODEL_DIR"/fastpitch
|
||||
mkdir -p "$MODEL_DIR"
|
||||
|
||||
if [ ! -f "${MODEL_DIR}/fastpitch/${FASTP_ZIP}" ]; then
|
||||
echo "Downloading ${FASTP_ZIP} ..."
|
||||
wget -qO ${MODEL_DIR}/fastpitch/${FASTP_ZIP} ${FASTP_URL}
|
||||
if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
|
||||
echo "Downloading ${MODEL_ZIP} ..."
|
||||
wget -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
|
||||
|| echo "ERROR: Failed to download ${MODEL_ZIP} from NGC" && exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "${MODEL_DIR}/fastpitch/${FASTP_CH}" ]; then
|
||||
echo "Extracting ${FASTP_CH} ..."
|
||||
unzip -qo ${MODEL_DIR}/fastpitch/${FASTP_ZIP} -d ${MODEL_DIR}/fastpitch/
|
||||
if [ ! -f "${MODEL_DIR}/${MODEL_CH}" ]; then
|
||||
echo "Extracting ${MODEL_CH} ..."
|
||||
unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR}
|
||||
fi
|
||||
|
|
|
@ -2,12 +2,18 @@
|
|||
|
||||
set -e
|
||||
|
||||
MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
|
||||
TACO_CH="nvidia_tacotron2pyt_fp32_20190427.pt"
|
||||
TACO_URL="https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/2/files/nvidia_tacotron2pyt_fp32_20190427"
|
||||
: ${MODEL_DIR:="pretrained_models/tacotron2"}
|
||||
MODEL="nvidia_tacotron2pyt_fp16.pt"
|
||||
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_amp/versions/19.12.0/files/nvidia_tacotron2pyt_fp16.pt"
|
||||
|
||||
if [ ! -f "${MODEL_DIR}/tacotron2/${TACO_CH}" ]; then
|
||||
echo "Downloading ${TACO_CH} ..."
|
||||
mkdir -p "$MODEL_DIR"/tacotron2
|
||||
wget -qO ${MODEL_DIR}/tacotron2/${TACO_CH} ${TACO_URL}
|
||||
mkdir -p "$MODEL_DIR"
|
||||
|
||||
if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
|
||||
echo "Downloading ${MODEL} ..."
|
||||
wget --content-disposition -qO ${MODEL_DIR}/${MODEL} ${MODEL_URL} \
|
||||
|| echo "ERROR: Failed to download ${MODEL} from NGC" && exit 1
|
||||
echo "OK"
|
||||
|
||||
else
|
||||
echo "${MODEL}.pt already downloaded."
|
||||
fi
|
||||
|
|
|
@ -2,19 +2,26 @@
|
|||
|
||||
set -e
|
||||
|
||||
MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
|
||||
WAVEG="waveglow_1076430_14000_amp"
|
||||
WAVEG_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow256pyt_fp16/versions/2/zip"
|
||||
: ${MODEL_DIR:="pretrained_models/waveglow"}
|
||||
MODEL="nvidia_waveglow256pyt_fp16"
|
||||
MODEL_ZIP="waveglow_ckpt_amp_256_20.01.0.zip"
|
||||
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp_256/versions/20.01.0/zip"
|
||||
|
||||
mkdir -p "$MODEL_DIR"/waveglow
|
||||
mkdir -p "$MODEL_DIR"
|
||||
|
||||
if [ ! -f "${MODEL_DIR}/waveglow/${WAVEG}.zip" ]; then
|
||||
echo "Downloading ${WAVEG}.zip ..."
|
||||
wget -qO "${MODEL_DIR}/waveglow/${WAVEG}.zip" ${WAVEG_URL}
|
||||
if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
|
||||
echo "Downloading ${MODEL_ZIP} ..."
|
||||
wget --content-disposition -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
|
||||
|| echo "ERROR: Failed to download ${MODEL_ZIP} from NGC" && exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "${MODEL_DIR}/waveglow/${WAVEG}.pt" ]; then
|
||||
echo "Extracting ${WAVEG} ..."
|
||||
unzip -qo "${MODEL_DIR}/waveglow/${WAVEG}.zip" -d ${MODEL_DIR}/waveglow/
|
||||
mv "${MODEL_DIR}/waveglow/${WAVEG}" "${MODEL_DIR}/waveglow/${WAVEG}.pt"
|
||||
if [ ! -f "${MODEL_DIR}/${MODEL}.pt" ]; then
|
||||
echo "Extracting ${MODEL} ..."
|
||||
unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
|
||||
|| echo "ERROR: Failed to extract ${MODEL_ZIP}" && exit 1
|
||||
|
||||
echo "OK"
|
||||
|
||||
else
|
||||
echo "${MODEL}.pt already downloaded."
|
||||
fi
|
||||
|
|
|
@ -1,22 +1,26 @@
|
|||
#!/bin/bash
|
||||
|
||||
[ ! -n "$WAVEG_CH" ] && WAVEG_CH="pretrained_models/waveglow/waveglow_1076430_14000_amp.pt"
|
||||
[ ! -n "$FASTPITCH_CH" ] && FASTPITCH_CH="output/FastPitch_checkpoint_1500.pt"
|
||||
[ ! -n "$REPEATS" ] && REPEATS=1000
|
||||
[ ! -n "$BS_SEQ" ] && BS_SEQ="1 4 8"
|
||||
[ ! -n "$PHRASES" ] && PHRASES="phrases/benchmark_8_128.tsv"
|
||||
[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
|
||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp" || AMP=false
|
||||
: ${WAVEGLOW:="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
|
||||
: ${FASTPITCH:="output/FastPitch_checkpoint_1500.pt"}
|
||||
: ${REPEATS:=1000}
|
||||
: ${BS_SEQUENCE:="1 4 8"}
|
||||
: ${PHRASES:="phrases/benchmark_8_128.tsv"}
|
||||
: ${OUTPUT_DIR:="./output/audio_$(basename ${PHRASES} .tsv)"}
|
||||
: ${AMP:=false}
|
||||
|
||||
for BS in $BS_SEQ ; do
|
||||
[ "$AMP" = true ] && AMP_FLAG="--amp"
|
||||
|
||||
mkdir -o "$OUTPUT_DIR"
|
||||
|
||||
for BS in $BS_SEQUENCE ; do
|
||||
|
||||
echo -e "\nAMP: ${AMP}, batch size: ${BS}\n"
|
||||
|
||||
python inference.py --cuda --cudnn-benchmark \
|
||||
-i ${PHRASES} \
|
||||
-o ${OUTPUT_DIR} \
|
||||
--fastpitch ${FASTPITCH_CH} \
|
||||
--waveglow ${WAVEG_CH} \
|
||||
--fastpitch ${FASTPITCH} \
|
||||
--waveglow ${WAVEGLOW} \
|
||||
--wn-channels 256 \
|
||||
--include-warmup \
|
||||
--batch-size ${BS} \
|
||||
|
|
|
@ -1,22 +1,21 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
DATA_DIR="LJSpeech-1.1"
|
||||
: ${WAVEGLOW:="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
|
||||
: ${FASTPITCH:="output/FastPitch_checkpoint_1500.pt"}
|
||||
: ${BS:=32}
|
||||
: ${PHRASES:="phrases/devset10.tsv"}
|
||||
: ${OUTPUT_DIR:="./output/audio_$(basename ${PHRASES} .tsv)"}
|
||||
: ${AMP:=false}
|
||||
|
||||
[ ! -n "$WAVEG_CH" ] && WAVEG_CH="pretrained_models/waveglow/waveglow_1076430_14000_amp.pt"
|
||||
[ ! -n "$FASTPITCH_CH" ] && FASTPITCH_CH="output/FastPitch_checkpoint_1500.pt"
|
||||
[ ! -n "$BS" ] && BS=32
|
||||
[ ! -n "$PHRASES" ] && PHRASES="phrases/devset10.tsv"
|
||||
[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
|
||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp"
|
||||
[ "$AMP" = true ] && AMP_FLAG="--amp"
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
python inference.py --cuda \
|
||||
-i ${PHRASES} \
|
||||
-o ${OUTPUT_DIR} \
|
||||
--dataset-path ${DATA_DIR} \
|
||||
--fastpitch ${FASTPITCH_CH} \
|
||||
--waveglow ${WAVEG_CH} \
|
||||
--fastpitch ${FASTPITCH} \
|
||||
--waveglow ${WAVEGLOW} \
|
||||
--wn-channels 256 \
|
||||
--batch-size ${BS} \
|
||||
${AMP_FLAG}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
set -e
|
||||
|
||||
DATA_DIR="LJSpeech-1.1"
|
||||
TACO_CH="pretrained_models/tacotron2/nvidia_tacotron2pyt_fp32_20190427.pt"
|
||||
TACOTRON2="pretrained_models/tacotron2/nvidia_tacotron2pyt_fp16.pt"
|
||||
for FILELIST in ljs_audio_text_train_filelist.txt \
|
||||
ljs_audio_text_val_filelist.txt \
|
||||
ljs_audio_text_test_filelist.txt \
|
||||
|
@ -16,5 +16,5 @@ for FILELIST in ljs_audio_text_train_filelist.txt \
|
|||
--extract-mels \
|
||||
--extract-durations \
|
||||
--extract-pitch-char \
|
||||
--tacotron2-checkpoint ${TACO_CH}
|
||||
--tacotron2-checkpoint ${TACOTRON2}
|
||||
done
|
||||
|
|
|
@ -2,24 +2,26 @@
|
|||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
: ${NUM_GPUS:=8}
|
||||
: ${BS:=32}
|
||||
: ${GRAD_ACCUMULATION:=1}
|
||||
: ${OUTPUT_DIR:="./output"}
|
||||
: ${AMP:=false}
|
||||
: ${EPOCHS:=1500}
|
||||
|
||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp"
|
||||
|
||||
# Adjust env variables to maintain the global batch size
|
||||
#
|
||||
# NGPU x BS x GRAD_ACC = 256.
|
||||
|
||||
[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output"
|
||||
[ ! -n "$NGPU" ] && NGPU=8
|
||||
[ ! -n "$BS" ] && BS=32
|
||||
[ ! -n "$GRAD_ACC" ] && GRAD_ACC=1
|
||||
[ ! -n "$EPOCHS" ] && EPOCHS=1500
|
||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp"
|
||||
|
||||
GBS=$(($NGPU * $BS * $GRAD_ACC))
|
||||
#
|
||||
GBS=$(($NUM_GPUS * $BS * $GRAD_ACCUMULATION))
|
||||
[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}.\n"
|
||||
|
||||
echo -e "\nSetup: ${NGPU}x${BS}x${GRAD_ACC} - global batch size ${GBS}\n"
|
||||
echo -e "\nSetup: ${NUM_GPUS}x${BS}x${GRAD_ACCUMULATION} - global batch size ${GBS}\n"
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
python -m torch.distributed.launch --nproc_per_node ${NGPU} train.py \
|
||||
python -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} train.py \
|
||||
--cuda \
|
||||
-o "$OUTPUT_DIR/" \
|
||||
--log-file "$OUTPUT_DIR/nvlog.json" \
|
||||
|
@ -37,5 +39,5 @@ python -m torch.distributed.launch --nproc_per_node ${NGPU} train.py \
|
|||
--dur-predictor-loss-scale 0.1 \
|
||||
--pitch-predictor-loss-scale 0.1 \
|
||||
--weight-decay 1e-6 \
|
||||
--gradient-accumulation-steps ${GRAD_ACC} \
|
||||
--gradient-accumulation-steps ${GRAD_ACCUMULATION} \
|
||||
${AMP_FLAG}
|
||||
|
|
|
@ -27,9 +27,6 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from common.text import symbols
|
||||
|
||||
|
||||
def parse_tacotron2_args(parent, add_help=False):
|
||||
"""
|
||||
Parse commandline arguments.
|
||||
|
@ -43,11 +40,7 @@ def parse_tacotron2_args(parent, add_help=False):
|
|||
help='Number of bins in mel-spectrograms')
|
||||
|
||||
# symbols parameters
|
||||
global symbols
|
||||
len_symbols = len(symbols)
|
||||
symbols = parser.add_argument_group('symbols parameters')
|
||||
symbols.add_argument('--n-symbols', default=len_symbols, type=int,
|
||||
help='Number of symbols in dictionary')
|
||||
symbols.add_argument('--symbols-embedding-dim', default=512, type=int,
|
||||
help='Input embedding dimension')
|
||||
|
||||
|
|
|
@ -33,8 +33,6 @@ import torch.utils.data
|
|||
|
||||
import common.layers as layers
|
||||
from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
|
||||
from common.text import text_to_sequence
|
||||
|
||||
|
||||
class TextMelLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
@ -43,8 +41,9 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
3) computes mel-spectrograms from audio files.
|
||||
"""
|
||||
def __init__(self, dataset_path, audiopaths_and_text, args, load_mel_from_disk=True):
|
||||
self.audiopaths_and_text = load_filepaths_and_text(dataset_path, audiopaths_and_text)
|
||||
self.text_cleaners = args.text_cleaners
|
||||
self.audiopaths_and_text = load_filepaths_and_text(
|
||||
dataset_path, audiopaths_and_text,
|
||||
has_speakers=(args.n_speakers > 1))
|
||||
self.load_mel_from_disk = load_mel_from_disk
|
||||
if not load_mel_from_disk:
|
||||
self.max_wav_value = args.max_wav_value
|
||||
|
@ -74,14 +73,14 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
return melspec
|
||||
|
||||
def get_text(self, text):
|
||||
text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
|
||||
return text_norm
|
||||
text_encoded = torch.IntTensor(self.tp.encode_text(text))
|
||||
return text_encoded
|
||||
|
||||
def __getitem__(self, index):
|
||||
# separate filename and text
|
||||
audiopath, text = self.audiopaths_and_text[index]
|
||||
len_text = len(text)
|
||||
text = self.get_text(text)
|
||||
len_text = len(text)
|
||||
mel = self.get_mel(audiopath)
|
||||
return (text, mel, len_text)
|
||||
|
||||
|
|
|
@ -108,14 +108,20 @@ def parse_args(parser):
|
|||
|
||||
dataset = parser.add_argument_group('dataset parameters')
|
||||
dataset.add_argument('--training-files', type=str, required=True,
|
||||
help='Path to training filelist')
|
||||
help='Path to training filelist. Separate multiple paths with commas.')
|
||||
dataset.add_argument('--validation-files', type=str, required=True,
|
||||
help='Path to validation filelist')
|
||||
help='Path to validation filelist. Separate multiple paths with commas.')
|
||||
dataset.add_argument('--pitch-mean-std-file', type=str, default=None,
|
||||
help='Path to pitch stats to be stored in the model')
|
||||
dataset.add_argument('--text-cleaners', nargs='*',
|
||||
default=['english_cleaners'], type=str,
|
||||
help='Type of text cleaners for input text')
|
||||
dataset.add_argument('--symbol-set', type=str, default='english_basic',
|
||||
help='Define symbol set for input text')
|
||||
|
||||
cond = parser.add_argument_group('conditioning on additional attributes')
|
||||
cond.add_argument('--n-speakers', type=int, default=1,
|
||||
help='Condition on speaker, value > 1 enables trainable speaker embeddings.')
|
||||
|
||||
distributed = parser.add_argument_group('distributed setup')
|
||||
distributed.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
|
||||
|
@ -316,8 +322,6 @@ def main():
|
|||
model = models.get_model('FastPitch', model_config, device)
|
||||
|
||||
# Store pitch mean/std as params to translate from Hz during inference
|
||||
fpath = common.utils.stats_filename(
|
||||
args.dataset_path, args.training_files, 'pitch_char')
|
||||
with open(args.pitch_mean_std_file, 'r') as f:
|
||||
stats = json.load(f)
|
||||
model.pitch_mean[0] = stats['mean']
|
||||
|
@ -530,6 +534,13 @@ def main():
|
|||
validate(model, None, total_iter, criterion, valset, args.batch_size,
|
||||
collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)
|
||||
|
||||
if (epoch > 0 and args.epochs_per_checkpoint > 0 and
|
||||
(epoch % args.epochs_per_checkpoint != 0) and args.local_rank == 0):
|
||||
checkpoint_path = os.path.join(
|
||||
args.output, f"FastPitch_checkpoint_{epoch}.pt")
|
||||
save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
|
||||
total_iter, model_config, args.amp, checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -58,6 +58,7 @@ class Invertible1x1Conv(torch.nn.Module):
|
|||
if torch.det(W) < 0:
|
||||
W[:, 0] = -1 * W[:, 0]
|
||||
W = W.view(c, c, 1)
|
||||
W = W.contiguous()
|
||||
self.conv.weight.data = W
|
||||
|
||||
def forward(self, z):
|
||||
|
@ -279,6 +280,49 @@ class WaveGlow(torch.nn.Module):
|
|||
return audio
|
||||
|
||||
|
||||
def infer_onnx(self, spect, z, sigma=0.9):
|
||||
|
||||
spect = self.upsample(spect)
|
||||
# trim conv artifacts. maybe pad spec to kernel multiple
|
||||
time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
|
||||
spect = spect[:, :, :-time_cutoff]
|
||||
|
||||
length_spect_group = spect.size(2)//8
|
||||
mel_dim = 80
|
||||
batch_size = spect.size(0)
|
||||
|
||||
spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
|
||||
spect = spect.permute(0, 2, 1, 3)
|
||||
spect = spect.contiguous()
|
||||
spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
|
||||
spect = spect.permute(0, 2, 1)
|
||||
spect = spect.contiguous()
|
||||
|
||||
audio = z[:, :self.n_remaining_channels, :]
|
||||
z = z[:, self.n_remaining_channels:self.n_group, :]
|
||||
audio = sigma*audio
|
||||
|
||||
for k in reversed(range(self.n_flows)):
|
||||
n_half = int(audio.size(1) // 2)
|
||||
audio_0 = audio[:, :n_half, :]
|
||||
audio_1 = audio[:, n_half:(n_half+n_half), :]
|
||||
|
||||
output = self.WN[k]((audio_0, spect))
|
||||
s = output[:, n_half:(n_half+n_half), :]
|
||||
b = output[:, :n_half, :]
|
||||
audio_1 = (audio_1 - b) / torch.exp(s)
|
||||
audio = torch.cat([audio_0, audio_1], 1)
|
||||
audio = self.convinv[k].infer(audio)
|
||||
|
||||
if k % self.n_early_every == 0 and k > 0:
|
||||
audio = torch.cat((z[:, :self.n_early_size, :], audio), 1)
|
||||
z = z[:, self.n_early_size:self.n_group, :]
|
||||
|
||||
audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
@staticmethod
|
||||
def remove_weightnorm(model):
|
||||
waveglow = model
|
||||
|
|
Loading…
Reference in a new issue