[FastPitch/PyT] updated checkpoints, multispeaker and text processing
This commit is contained in:
parent
03c5a9fd9b
commit
bec82593f5
17
PyTorch/SpeechSynthesis/FastPitch/.gitignore
vendored
17
PyTorch/SpeechSynthesis/FastPitch/.gitignore
vendored
|
@ -1,8 +1,15 @@
|
||||||
*.swp
|
|
||||||
*.swo
|
|
||||||
*.pyc
|
|
||||||
__pycache__
|
|
||||||
scripts_joc/
|
|
||||||
runs*/
|
runs*/
|
||||||
LJSpeech-1.1/
|
LJSpeech-1.1/
|
||||||
output*
|
output*
|
||||||
|
scripts_joc/
|
||||||
|
tests/
|
||||||
|
|
||||||
|
*.pyc
|
||||||
|
__pycache__
|
||||||
|
|
||||||
|
.idea/
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*.swn
|
||||||
|
|
|
@ -488,11 +488,11 @@ The `scripts/train.sh` script is configured for 8x GPU with at least 16GB of mem
|
||||||
```
|
```
|
||||||
In a single accumulated step, there are `batch_size x gradient_accumulation_steps x GPUs = 256` examples being processed in parallel. With a smaller number of GPUs, increase `--gradient_accumulation_steps` to keep this relation satisfied, e.g., through env variables
|
In a single accumulated step, there are `batch_size x gradient_accumulation_steps x GPUs = 256` examples being processed in parallel. With a smaller number of GPUs, increase `--gradient_accumulation_steps` to keep this relation satisfied, e.g., through env variables
|
||||||
```bash
|
```bash
|
||||||
NGPU=4 GRAD_ACC=2 bash scripts/train.sh
|
NUM_GPUS=4 GRAD_ACCUMULATION=2 bash scripts/train.sh
|
||||||
```
|
```
|
||||||
With automatic mixed precision (AMP), a larger batch size fits in 16GB of memory:
|
With automatic mixed precision (AMP), a larger batch size fits in 16GB of memory:
|
||||||
```bash
|
```bash
|
||||||
NGPU=4 GRAD_ACC=1 BS=64 AMP=true bash scripts/train.sh
|
NUM_GPUS=4 GRAD_ACCUMULATION=1 BS=64 AMP=true bash scripts/train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
### Inference process
|
### Inference process
|
||||||
|
@ -545,18 +545,18 @@ To benchmark the training performance on a specific batch size, run:
|
||||||
|
|
||||||
* NVIDIA DGX A100 (8x A100 40GB)
|
* NVIDIA DGX A100 (8x A100 40GB)
|
||||||
```bash
|
```bash
|
||||||
AMP=true NGPU=1 BS=128 GRAD_ACC=2 EPOCHS=10 bash scripts/train.sh
|
AMP=true NUM_GPUS=1 BS=128 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
|
||||||
AMP=true NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
AMP=true NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||||
NGPU=1 BS=128 GRAD_ACC=2 EPOCHS=10 bash scripts/train.sh
|
NUM_GPUS=1 BS=128 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
|
||||||
NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
* NVIDIA DGX-1 (8x V100 16GB)
|
* NVIDIA DGX-1 (8x V100 16GB)
|
||||||
```bash
|
```bash
|
||||||
AMP=true NGPU=1 BS=64 GRAD_ACC=4 EPOCHS=10 bash scripts/train.sh
|
AMP=true NUM_GPUS=1 BS=64 GRAD_ACCUMULATION=4 EPOCHS=10 bash scripts/train.sh
|
||||||
AMP=true NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
AMP=true NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||||
NGPU=1 BS=32 GRAD_ACC=8 EPOCHS=10 bash scripts/train.sh
|
NUM_GPUS=1 BS=32 GRAD_ACCUMULATION=8 EPOCHS=10 bash scripts/train.sh
|
||||||
NGPU=8 BS=32 GRAD_ACC=1 EPOCHS=10 bash scripts/train.sh
|
NUM_GPUS=8 BS=32 GRAD_ACCUMULATION=1 EPOCHS=10 bash scripts/train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
Each of these scripts runs for 10 epochs and for each epoch measures the
|
Each of these scripts runs for 10 epochs and for each epoch measures the
|
||||||
|
@ -569,12 +569,12 @@ To benchmark the inference performance on a specific batch size, run:
|
||||||
|
|
||||||
* For FP16
|
* For FP16
|
||||||
```bash
|
```bash
|
||||||
AMP=true BS_SEQ=”1 4 8” REPEATS=100 bash scripts/inference_benchmark.sh
|
AMP=true BS_SEQUENCE="1 4 8" REPEATS=100 bash scripts/inference_benchmark.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
* For FP32 or TF32
|
* For FP32 or TF32
|
||||||
```bash
|
```bash
|
||||||
BS_SEQ=”1 4 8” REPEATS=100 bash scripts/inference_benchmark.sh
|
BS_SEQUENCE="1 4 8" REPEATS=100 bash scripts/inference_benchmark.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
The output log files will contain performance numbers for the FastPitch model
|
The output log files will contain performance numbers for the FastPitch model
|
||||||
|
@ -726,6 +726,10 @@ The input utterance has 128 characters, synthesized audio has 8.05 s.
|
||||||
|
|
||||||
### Changelog
|
### Changelog
|
||||||
|
|
||||||
|
October 2020
|
||||||
|
- Added multispeaker capabilities
|
||||||
|
- Updated text processing module
|
||||||
|
|
||||||
June 2020
|
June 2020
|
||||||
- Updated performance tables to include A100 results
|
- Updated performance tables to include A100 results
|
||||||
|
|
||||||
|
|
|
@ -1,74 +1,3 @@
|
||||||
""" from https://github.com/keithito/tacotron """
|
from .cmudict import CMUDict
|
||||||
import re
|
|
||||||
from common.text import cleaners
|
|
||||||
from common.text.symbols import symbols
|
|
||||||
|
|
||||||
|
cmudict = CMUDict()
|
||||||
# Mappings from symbol to numeric ID and vice versa:
|
|
||||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
|
||||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
|
||||||
|
|
||||||
# Regular expression matching text enclosed in curly braces:
|
|
||||||
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
|
||||||
|
|
||||||
|
|
||||||
def text_to_sequence(text, cleaner_names):
    '''Convert a string of text to a sequence of symbol IDs.

    The text can optionally have ARPAbet sequences enclosed in curly braces
    embedded in it, e.g. "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through

    Returns:
        List of integers corresponding to the symbols in the text
    '''
    ids = []
    remaining = text
    # Consume the input piece by piece, treating {...} spans as ARPAbet.
    while remaining:
        match = _curly_re.match(remaining)
        if match is None:
            # No curly-brace span left: clean and encode the whole tail.
            ids.extend(_symbols_to_sequence(_clean_text(remaining, cleaner_names)))
            break
        plain, arpabet, rest = match.group(1), match.group(2), match.group(3)
        ids.extend(_symbols_to_sequence(_clean_text(plain, cleaner_names)))
        ids.extend(_arpabet_to_sequence(arpabet))
        remaining = rest
    return ids
|
|
||||||
|
|
||||||
|
|
||||||
def sequence_to_text(sequence):
    '''Convert a sequence of symbol IDs back into a string.

    IDs with no known symbol are skipped; ARPAbet symbols (stored with a
    leading '@') are re-wrapped in curly braces.
    '''
    pieces = []
    for symbol_id in sequence:
        symbol = _id_to_symbol.get(symbol_id)
        if symbol is None:
            continue
        # Enclose ARPAbet back in curly braces:
        if len(symbol) > 1 and symbol.startswith('@'):
            symbol = '{%s}' % symbol[1:]
        pieces.append(symbol)
    # Adjacent ARPAbet spans merge into one space-separated brace pair.
    return ''.join(pieces).replace('}{', ' ')
|
|
||||||
|
|
||||||
|
|
||||||
def _clean_text(text, cleaner_names):
    '''Run *text* through each named cleaner from common.text.cleaners, in order.'''
    cleaned = text
    for cleaner_name in cleaner_names:
        # getattr raises AttributeError for unknown names; the explicit check
        # additionally rejects attributes that are falsy.
        cleaner_fn = getattr(cleaners, cleaner_name)
        if not cleaner_fn:
            raise Exception('Unknown cleaner: %s' % cleaner_name)
        cleaned = cleaner_fn(cleaned)
    return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def _symbols_to_sequence(symbols):
    '''Map each keepable symbol to its numeric ID, skipping pad/EOS and unknowns.'''
    ids = []
    for sym in symbols:
        if _should_keep_symbol(sym):
            ids.append(_symbol_to_id[sym])
    return ids
|
|
||||||
|
|
||||||
|
|
||||||
def _arpabet_to_sequence(text):
    '''Encode a space-separated ARPAbet string; '@' marks phoneme symbols.'''
    prefixed = ['@' + phoneme for phoneme in text.split()]
    return _symbols_to_sequence(prefixed)
|
|
||||||
|
|
||||||
|
|
||||||
def _should_keep_symbol(s):
    '''Return True for symbols that map to an ID, excluding padding '_' and EOS '~'.

    Bug fix: the original used ``s is not '_'`` / ``s is not '~'``, which tests
    object identity rather than value and only works by accident of CPython
    string interning (it also emits a SyntaxWarning on modern Pythons).
    Use value comparison instead.
    '''
    return s in _symbol_to_id and s != '_' and s != '~'
|
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
import re

# "No." followed (optionally after one space) by a digit, e.g. "No. 5".
_no_period_re = re.compile(r'(No[.])(?=[ ]?[0-9])')
# A percent sign, optionally preceded by a single space.
_percent_re = re.compile(r'([ ]?[%])')
# The vulgar fraction 1/2, either attached to a digit or standing alone.
_half_re = re.compile('([0-9]½)|(½)')


# (compiled regex, replacement) pairs for common abbreviations.
_abbreviations = [
    (re.compile(r'\b%s\.' % abbr, re.IGNORECASE), expansion)
    for abbr, expansion in [
        ('mrs', 'misess'),
        ('ms', 'miss'),
        ('mr', 'mister'),
        ('dr', 'doctor'),
        ('st', 'saint'),
        ('co', 'company'),
        ('jr', 'junior'),
        ('maj', 'major'),
        ('gen', 'general'),
        ('drs', 'doctors'),
        ('rev', 'reverend'),
        ('lt', 'lieutenant'),
        ('hon', 'honorable'),
        ('sgt', 'sergeant'),
        ('capt', 'captain'),
        ('esq', 'esquire'),
        ('ltd', 'limited'),
        ('col', 'colonel'),
        ('ft', 'fort'),
        ('sen', 'senator'),
    ]
]


def _expand_no_period(m):
    '''Replace "No." with "Number"/"number", preserving the original casing.'''
    return 'Number' if m.group(0)[0] == 'N' else 'number'


def _expand_percent(m):
    '''Replace '%' (and any single preceding space) with ' percent'.'''
    return ' percent'


def _expand_half(m):
    '''Replace '½' with 'half', or 'N½' with 'N and a half'.'''
    digit_half = m.group(1)
    if digit_half is None:
        return 'half'
    return digit_half[0] + ' and a half'


def normalize_abbreviations(text):
    '''Expand "No.", '%', '½', and common title abbreviations in *text*.'''
    text = re.sub(_no_period_re, _expand_no_period, text)
    text = re.sub(_percent_re, _expand_percent, text)
    text = re.sub(_half_re, _expand_half, text)
    for pattern, expansion in _abbreviations:
        text = re.sub(pattern, expansion, text)
    return text
|
67
PyTorch/SpeechSynthesis/FastPitch/common/text/acronyms.py
Normal file
67
PyTorch/SpeechSynthesis/FastPitch/common/text/acronyms.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
import re
from . import cmudict

# ARPAbet spellings for individual letters; lowercase 's' handles plurals.
_letter_to_arpabet = {
    'A': 'EY1',
    'B': 'B IY1',
    'C': 'S IY1',
    'D': 'D IY1',
    'E': 'IY1',
    'F': 'EH1 F',
    'G': 'JH IY1',
    'H': 'EY1 CH',
    'I': 'AY1',
    'J': 'JH EY1',
    'K': 'K EY1',
    'L': 'EH1 L',
    'M': 'EH1 M',
    'N': 'EH1 N',
    'O': 'OW1',
    'P': 'P IY1',
    'Q': 'K Y UW1',
    'R': 'AA1 R',
    'S': 'EH1 S',
    'T': 'T IY1',
    'U': 'Y UW1',
    'V': 'V IY1',
    'X': 'EH1 K S',
    'Y': 'W AY1',
    'W': 'D AH1 B AH0 L Y UW0',
    'Z': 'Z IY1',
    's': 'Z',
}

# must ignore roman numerals
# _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
_acronym_re = re.compile(r'([A-Z][A-Z]+)s?')


def _expand_acronyms(m, add_spaces=True):
    '''Return an ARPAbet pronunciation (in curly braces) for an acronym match.

    Prefers a CMUDict lookup; falls back to spelling the acronym letter by
    letter via _letter_to_arpabet.
    '''
    acronym = m.group(0)
    # Strip any periods, then any internal whitespace.
    acronym = re.sub(r'\.', '', acronym)
    acronym = ''.join(acronym.split())

    arpabet = cmudict.lookup(acronym)
    if arpabet is None:
        # Spell it out one letter at a time.
        arpabet = ['{' + _letter_to_arpabet[letter] + '}' for letter in acronym]
        # temporary fix: fold a trailing plural 'Z' into the preceding letter.
        if arpabet[-1] == '{Z}' and len(arpabet) > 1:
            arpabet[-2] = arpabet[-2][:-1] + ' ' + arpabet[-1][1:]
            del arpabet[-1]
        arpabet = ' '.join(arpabet)
    elif len(arpabet) == 1:
        arpabet = '{' + arpabet[0] + '}'
    else:
        # Ambiguous pronunciation: leave the acronym unchanged.
        arpabet = acronym
    return arpabet


def normalize_acronyms(text):
    '''Replace all-caps acronyms in *text* with ARPAbet pronunciations.'''
    return re.sub(_acronym_re, _expand_acronyms, text)
|
|
@ -1,90 +1,92 @@
|
||||||
""" from https://github.com/keithito/tacotron """
|
""" adapted from https://github.com/keithito/tacotron """
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Cleaners are transformations that run over the input text at both training and eval time.
|
Cleaners are transformations that run over the input text at both training and eval time.
|
||||||
|
|
||||||
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
||||||
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
||||||
1. "english_cleaners" for English text
|
1. "english_cleaners" for English text
|
||||||
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||||
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||||
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||||
the symbols in symbols.py to match your data).
|
the symbols in symbols.py to match your data).
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from unidecode import unidecode
|
from unidecode import unidecode
|
||||||
from .numbers import normalize_numbers
|
from .numerical import normalize_numbers
|
||||||
|
from .acronyms import normalize_acronyms
|
||||||
|
from .datestime import normalize_datestime
|
||||||
|
from .letters_and_numbers import normalize_letters_and_numbers
|
||||||
|
from .abbreviations import normalize_abbreviations
|
||||||
|
|
||||||
|
|
||||||
# Regular expression matching whitespace:
|
# Regular expression matching whitespace:
|
||||||
_whitespace_re = re.compile(r'\s+')
|
_whitespace_re = re.compile(r'\s+')
|
||||||
|
|
||||||
# List of (regular expression, replacement) pairs for abbreviations:
|
|
||||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
|
||||||
('mrs', 'misess'),
|
|
||||||
('mr', 'mister'),
|
|
||||||
('dr', 'doctor'),
|
|
||||||
('st', 'saint'),
|
|
||||||
('co', 'company'),
|
|
||||||
('jr', 'junior'),
|
|
||||||
('maj', 'major'),
|
|
||||||
('gen', 'general'),
|
|
||||||
('drs', 'doctors'),
|
|
||||||
('rev', 'reverend'),
|
|
||||||
('lt', 'lieutenant'),
|
|
||||||
('hon', 'honorable'),
|
|
||||||
('sgt', 'sergeant'),
|
|
||||||
('capt', 'captain'),
|
|
||||||
('esq', 'esquire'),
|
|
||||||
('ltd', 'limited'),
|
|
||||||
('col', 'colonel'),
|
|
||||||
('ft', 'fort'),
|
|
||||||
]]
|
|
||||||
|
|
||||||
|
|
||||||
def expand_abbreviations(text):
|
def expand_abbreviations(text):
|
||||||
for regex, replacement in _abbreviations:
|
return normalize_abbreviations(text)
|
||||||
text = re.sub(regex, replacement, text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def expand_numbers(text):
|
def expand_numbers(text):
|
||||||
return normalize_numbers(text)
|
return normalize_numbers(text)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_acronyms(text):
|
||||||
|
return normalize_acronyms(text)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_datestime(text):
|
||||||
|
return normalize_datestime(text)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_letters_and_numbers(text):
|
||||||
|
return normalize_letters_and_numbers(text)
|
||||||
|
|
||||||
|
|
||||||
def lowercase(text):
|
def lowercase(text):
|
||||||
return text.lower()
|
return text.lower()
|
||||||
|
|
||||||
|
|
||||||
def collapse_whitespace(text):
|
def collapse_whitespace(text):
|
||||||
return re.sub(_whitespace_re, ' ', text)
|
return re.sub(_whitespace_re, ' ', text)
|
||||||
|
|
||||||
|
|
||||||
|
def separate_acronyms(text):
|
||||||
|
text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
|
||||||
|
text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def convert_to_ascii(text):
|
def convert_to_ascii(text):
|
||||||
return unidecode(text)
|
return unidecode(text)
|
||||||
|
|
||||||
|
|
||||||
def basic_cleaners(text):
|
def basic_cleaners(text):
|
||||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
'''Basic pipeline that collapses whitespace without transliteration.'''
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def transliteration_cleaners(text):
|
def transliteration_cleaners(text):
|
||||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||||
text = convert_to_ascii(text)
|
text = convert_to_ascii(text)
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def english_cleaners_post_chars(word):
|
||||||
|
return word
|
||||||
|
|
||||||
|
|
||||||
def english_cleaners(text):
    '''Pipeline for English text, with number and abbreviation expansion.'''
    cleaning_steps = (
        convert_to_ascii,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    )
    for step in cleaning_steps:
        text = step(text)
    return text
|
||||||
|
|
|
@ -18,7 +18,18 @@ _valid_symbol_set = set(valid_symbols)
|
||||||
|
|
||||||
class CMUDict:
|
class CMUDict:
|
||||||
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
|
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
|
||||||
def __init__(self, file_or_path, keep_ambiguous=True):
|
def __init__(self, file_or_path=None, heteronyms_path=None, keep_ambiguous=True):
|
||||||
|
if file_or_path is None:
|
||||||
|
self._entries = {}
|
||||||
|
else:
|
||||||
|
self.initialize(file_or_path, keep_ambiguous)
|
||||||
|
|
||||||
|
if heteronyms_path is None:
|
||||||
|
self.heteronyms = []
|
||||||
|
else:
|
||||||
|
self.heteronyms = set(lines_to_list(heteronyms_path))
|
||||||
|
|
||||||
|
def initialize(self, file_or_path, keep_ambiguous=True):
|
||||||
if isinstance(file_or_path, str):
|
if isinstance(file_or_path, str):
|
||||||
with open(file_or_path, encoding='latin-1') as f:
|
with open(file_or_path, encoding='latin-1') as f:
|
||||||
entries = _parse_cmudict(f)
|
entries = _parse_cmudict(f)
|
||||||
|
@ -28,17 +39,18 @@ class CMUDict:
|
||||||
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
||||||
self._entries = entries
|
self._entries = entries
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
if len(self._entries) == 0:
|
||||||
|
raise ValueError("CMUDict not initialized")
|
||||||
return len(self._entries)
|
return len(self._entries)
|
||||||
|
|
||||||
|
|
||||||
def lookup(self, word):
|
def lookup(self, word):
|
||||||
'''Returns list of ARPAbet pronunciations of the given word.'''
|
'''Returns list of ARPAbet pronunciations of the given word.'''
|
||||||
|
if len(self._entries) == 0:
|
||||||
|
raise ValueError("CMUDict not initialized")
|
||||||
return self._entries.get(word.upper())
|
return self._entries.get(word.upper())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_alt_re = re.compile(r'\([0-9]+\)')
|
_alt_re = re.compile(r'\([0-9]+\)')
|
||||||
|
|
||||||
|
|
||||||
|
|
22
PyTorch/SpeechSynthesis/FastPitch/common/text/datestime.py
Normal file
22
PyTorch/SpeechSynthesis/FastPitch/common/text/datestime.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import re

# Clock times with an am/pm marker, e.g. "10:30 pm", "8AM".
_ampm_re = re.compile(
    r'([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)')


def _expand_ampm(m):
    '''Spell out an am/pm time match, e.g. '10:30 PM' -> '10 30 p.m.'.'''
    # groups(0) turns an unmatched minutes group into 0, so int() below works.
    hour, minutes, meridiem = m.groups(0)
    spoken = hour
    if int(minutes) != 0:
        spoken = spoken + ' ' + minutes
    if meridiem[0].lower() == 'a':
        spoken += ' a.m.'
    elif meridiem[0].lower() == 'p':
        spoken += ' p.m.'
    return spoken


def normalize_datestime(text):
    '''Rewrite clock times in *text* into a speakable form.'''
    text = re.sub(_ampm_re, _expand_ampm, text)
    #text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
    return text
|
|
@ -0,0 +1,90 @@
|
||||||
|
import re

# Mixed letter/digit tokens like "RTX2080" or "AK47's".
_letters_and_numbers_re = re.compile(
    r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE)

# A quantity followed by a hardware/measurement unit abbreviation.
# Bug fix: 'cm' and 'km' had entries in _hardware_key but were missing from
# this pattern, so those dict entries were unreachable; they are matched now.
_hardware_re = re.compile(
    r'([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm|cm|km)', re.IGNORECASE)
_hardware_key = {'tb': 'terabyte',
                 'gb': 'gigabyte',
                 'mb': 'megabyte',
                 'kb': 'kilobyte',
                 'ghz': 'gigahertz',
                 'mhz': 'megahertz',
                 'khz': 'kilohertz',
                 'hz': 'hertz',
                 'mm': 'millimeter',
                 'cm': 'centimeter',
                 'km': 'kilometer'}

# Dimension expressions like "3x4", "2 x 4 x 8in" (two or three factors).
_dimension_re = re.compile(
    r'\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b')
_dimension_key = {'m': 'meter',
                  'in': 'inch',
                  'inch': 'inch'}


def _expand_letters_and_numbers(m):
    '''Split a mixed token into speakable pieces, e.g. 'RTX2080' -> 'RTX 20 80'.

    Digit runs are regrouped two by two (with special cases for years such as
    1920s and suffixes such as 47's / 20th), letters are kept as-is.
    '''
    text = re.split(r'(\d+)', m.group(0))

    # re.split leaves an empty string at whichever edge is a digit boundary.
    if text[-1] == '':
        text = text[:-1]
    elif text[0] == '':
        text = text[1:]

    # if not like 1920s, or AK47's , 20th, 1st, 2nd, 3rd, etc...
    if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
        # Keep the suffix attached to the digits so it is not regrouped below.
        text[-2] = text[-2] + text[-1]
        text = text[:-1]

    # for combining digits 2 by 2
    new_text = []
    for i in range(len(text)):
        string = text[i]
        if string.isdigit() and len(string) < 5:
            # heuristics: keep round hundreds/thousands intact, otherwise pair up
            if len(string) > 2 and string[-2] == '0':
                if string[-1] == '0':
                    string = [string]
                else:
                    string = [string[:-3], string[-2], string[-1]]
            elif len(string) % 2 == 0:
                string = [string[i:i+2] for i in range(0, len(string), 2)]
            elif len(string) > 2:
                string = [string[0]] + [string[i:i+2] for i in range(1, len(string), 2)]
            new_text.extend(string)
        else:
            new_text.append(string)

    text = new_text
    text = " ".join(text)
    return text


def _expand_hardware(m):
    '''Expand a unit abbreviation, e.g. '16GB' -> '16 gigabytes'.

    Hertz-family units (ending in 'z') are never pluralized.
    '''
    quantity, measure = m.groups(0)
    measure = _hardware_key[measure.lower()]
    if measure[-1] != 'z' and float(quantity.replace(',', '')) > 1:
        return "{} {}s".format(quantity, measure)
    return "{} {}".format(quantity, measure)


def _expand_dimension(m):
    '''Expand dimension expressions, e.g. '3x4' -> '3 by 4'.

    A trailing unit from _dimension_key is spelled out when preceded by a digit.
    '''
    # Exactly one of the two alternation groups matched; the other defaults to 0.
    text = "".join([x for x in m.groups(0) if x != 0])
    text = text.replace(' x ', ' by ')
    text = text.replace('x', ' by ')
    if text.endswith(tuple(_dimension_key.keys())):
        if text[-2].isdigit():
            text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
        elif text[-3].isdigit():
            text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
    return text


def normalize_letters_and_numbers(text):
    '''Expand hardware units, dimension expressions, and mixed letter/digit tokens.'''
    text = re.sub(_hardware_re, _expand_hardware, text)
    text = re.sub(_dimension_re, _expand_dimension, text)
    text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
    return text
|
|
@ -1,71 +0,0 @@
|
||||||
""" from https://github.com/keithito/tacotron """
|
|
||||||
|
|
||||||
import inflect
|
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
_inflect = inflect.engine()
|
|
||||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
|
||||||
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
|
||||||
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
|
||||||
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
|
||||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
|
||||||
_number_re = re.compile(r'[0-9]+')
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_commas(m):
|
|
||||||
return m.group(1).replace(',', '')
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_decimal_point(m):
|
|
||||||
return m.group(1).replace('.', ' point ')
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_dollars(m):
|
|
||||||
match = m.group(1)
|
|
||||||
parts = match.split('.')
|
|
||||||
if len(parts) > 2:
|
|
||||||
return match + ' dollars' # Unexpected format
|
|
||||||
dollars = int(parts[0]) if parts[0] else 0
|
|
||||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
|
||||||
if dollars and cents:
|
|
||||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
|
||||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
|
||||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
|
||||||
elif dollars:
|
|
||||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
|
||||||
return '%s %s' % (dollars, dollar_unit)
|
|
||||||
elif cents:
|
|
||||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
|
||||||
return '%s %s' % (cents, cent_unit)
|
|
||||||
else:
|
|
||||||
return 'zero dollars'
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_ordinal(m):
|
|
||||||
return _inflect.number_to_words(m.group(0))
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_number(m):
|
|
||||||
num = int(m.group(0))
|
|
||||||
if num > 1000 and num < 3000:
|
|
||||||
if num == 2000:
|
|
||||||
return 'two thousand'
|
|
||||||
elif num > 2000 and num < 2010:
|
|
||||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
|
||||||
elif num % 100 == 0:
|
|
||||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
|
||||||
else:
|
|
||||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
|
||||||
else:
|
|
||||||
return _inflect.number_to_words(num, andword='')
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_numbers(text):
|
|
||||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
|
||||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
|
||||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
|
||||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
|
||||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
|
||||||
text = re.sub(_number_re, _expand_number, text)
|
|
||||||
return text
|
|
153
PyTorch/SpeechSynthesis/FastPitch/common/text/numerical.py
Normal file
153
PyTorch/SpeechSynthesis/FastPitch/common/text/numerical.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
""" adapted from https://github.com/keithito/tacotron """
|
||||||
|
|
||||||
|
import inflect
|
||||||
|
import re
|
||||||
|
_magnitudes = ['trillion', 'billion', 'million', 'thousand', 'hundred', 'm', 'b', 't']
|
||||||
|
_magnitudes_key = {'m': 'million', 'b': 'billion', 't': 'trillion'}
|
||||||
|
_measurements = '(f|c|k|d|m)'
|
||||||
|
_measurements_key = {'f': 'fahrenheit',
|
||||||
|
'c': 'celsius',
|
||||||
|
'k': 'thousand',
|
||||||
|
'm': 'meters'}
|
||||||
|
_currency_key = {'$': 'dollar', '£': 'pound', '€': 'euro', '₩': 'won'}
|
||||||
|
_inflect = inflect.engine()
|
||||||
|
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||||
|
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||||
|
_currency_re = re.compile(r'([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?'.format("|".join(_magnitudes)), re.IGNORECASE)
|
||||||
|
_measurement_re = re.compile(r'([0-9\.\,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
|
||||||
|
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||||
|
# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
|
||||||
|
_roman_re = re.compile(r'\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b') # avoid I
|
||||||
|
_multiply_re = re.compile(r'(\b[0-9]+)(x)([0-9]+)')
|
||||||
|
_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
|
||||||
|
|
||||||
|
def _remove_commas(m):
|
||||||
|
return m.group(1).replace(',', '')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_decimal_point(m):
|
||||||
|
return m.group(1).replace('.', ' point ')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_currency(m):
|
||||||
|
currency = _currency_key[m.group(1)]
|
||||||
|
quantity = m.group(2)
|
||||||
|
magnitude = m.group(3)
|
||||||
|
|
||||||
|
# remove commas from quantity to be able to convert to numerical
|
||||||
|
quantity = quantity.replace(',', '')
|
||||||
|
|
||||||
|
# check for million, billion, etc...
|
||||||
|
if magnitude is not None and magnitude.lower() in _magnitudes:
|
||||||
|
if len(magnitude) == 1:
|
||||||
|
magnitude = _magnitudes_key[magnitude.lower()]
|
||||||
|
return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency+'s')
|
||||||
|
|
||||||
|
parts = quantity.split('.')
|
||||||
|
if len(parts) > 2:
|
||||||
|
return quantity + " " + currency + "s" # Unexpected format
|
||||||
|
|
||||||
|
dollars = int(parts[0]) if parts[0] else 0
|
||||||
|
|
||||||
|
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||||
|
if dollars and cents:
|
||||||
|
dollar_unit = currency if dollars == 1 else currency+'s'
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return "{} {}, {} {}".format(
|
||||||
|
_expand_hundreds(dollars), dollar_unit,
|
||||||
|
_inflect.number_to_words(cents), cent_unit)
|
||||||
|
elif dollars:
|
||||||
|
dollar_unit = currency if dollars == 1 else currency+'s'
|
||||||
|
return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
|
||||||
|
elif cents:
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
|
||||||
|
else:
|
||||||
|
return 'zero' + ' ' + currency + 's'
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_hundreds(text):
|
||||||
|
number = float(text)
|
||||||
|
if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
|
||||||
|
return _inflect.number_to_words(int(number / 100)) + " hundred"
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_ordinal(m):
|
||||||
|
return _inflect.number_to_words(m.group(0))
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_measurement(m):
|
||||||
|
_, number, measurement = re.split('(\d+(?:\.\d+)?)', m.group(0))
|
||||||
|
number = _inflect.number_to_words(number)
|
||||||
|
measurement = "".join(measurement.split())
|
||||||
|
measurement = _measurements_key[measurement.lower()]
|
||||||
|
return "{} {}".format(number, measurement)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_range(m):
|
||||||
|
return ' to '
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_multiply(m):
|
||||||
|
left = m.group(1)
|
||||||
|
right = m.group(3)
|
||||||
|
return "{} by {}".format(left, right)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_roman(m):
|
||||||
|
# from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
|
||||||
|
roman_numerals = {'I':1, 'V':5, 'X':10, 'L':50, 'C':100, 'D':500, 'M':1000}
|
||||||
|
result = 0
|
||||||
|
num = m.group(0)
|
||||||
|
for i, c in enumerate(num):
|
||||||
|
if (i+1) == len(num) or roman_numerals[c] >= roman_numerals[num[i+1]]:
|
||||||
|
result += roman_numerals[c]
|
||||||
|
else:
|
||||||
|
result -= roman_numerals[c]
|
||||||
|
return str(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_number(m):
    """Spell out an integer match with natural "hundreds" and year-style readings.

    Examples: 1500 -> "fifteen hundred", 2000 -> "two thousand",
    2007 -> "two thousand seven", 1999 -> "nineteen ninety nine",
    123 -> "one hundred and twenty three".  A trailing "s"/"'s" suffix is
    re-attached, pluralizing words ending in "y" ("1920s" -> spoken twenties).
    """
    _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
    number = int(number)
    # BUGFIX: the original condition `number > 1000 < 10000` is a chained
    # comparison that reduces to just `number > 1000`, so e.g. 10100 was read
    # as "one hundred one hundred".  The intended bound (matching
    # _expand_hundreds) is the exclusive range 1000..10000.
    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
        text = _inflect.number_to_words(number // 100) + " hundred"
    elif number > 1000 and number < 3000:
        if number == 2000:
            text = 'two thousand'
        elif number > 2000 and number < 2010:
            # 2001..2009 read as "two thousand <units>".
            text = 'two thousand ' + _inflect.number_to_words(number % 100)
        elif number % 100 == 0:
            text = _inflect.number_to_words(number // 100) + ' hundred'
        else:
            # Year-style reading in two-digit groups: 1999 -> "nineteen ninety nine".
            number = _inflect.number_to_words(number, andword='', zero='oh', group=2).replace(', ', ' ')
            number = re.sub(r'-', ' ', number)
            text = number
    else:
        number = _inflect.number_to_words(number, andword='and')
        number = re.sub(r'-', ' ', number)
        number = re.sub(r',', '', number)
        text = number

    if suffix in ("'s", "s"):
        if text[-1] == 'y':
            # "twenty" + "s" -> "twenties"
            text = text[:-1] + 'ies'
        else:
            text = text + suffix

    return text
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_numbers(text):
    """Apply the full pipeline of numeric expansions to `text`, in order."""
    # Order matters: commas are stripped first, currency before plain decimals,
    # ordinals before bare numbers, etc.  Range and measurement expansion are
    # currently disabled, matching the original pipeline.
    pipeline = (
        (_comma_number_re, _remove_commas),
        (_currency_re, _expand_currency),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        # (_range_re, _expand_range),
        # (_measurement_re, _expand_measurement),
        (_roman_re, _expand_roman),
        (_multiply_re, _expand_multiply),
        (_number_re, _expand_number),
    )
    for pattern, replacement in pipeline:
        text = re.sub(pattern, replacement, text)
    return text
|
|
@ -4,16 +4,41 @@
|
||||||
Defines the set of symbols used in text input to the model.
|
Defines the set of symbols used in text input to the model.
|
||||||
|
|
||||||
The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
|
The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
|
||||||
from common.text import cmudict
|
from .cmudict import valid_symbols
|
||||||
|
|
||||||
_pad = '_'
|
|
||||||
_punctuation = '!\'(),.:;? '
|
|
||||||
_special = '-'
|
|
||||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
|
||||||
|
|
||||||
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
||||||
_arpabet = ['@' + s for s in cmudict.valid_symbols]
|
_arpabet = ['@' + s for s in valid_symbols]
|
||||||
|
|
||||||
# Export all symbols:
|
|
||||||
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
|
def get_symbols(symbol_set='english_basic'):
    """Return the symbol inventory (characters followed by ARPAbet tokens)
    for the given `symbol_set`.

    Raises an Exception for unknown symbol sets.
    """
    if symbol_set == 'english_basic':
        pad = '_'
        special = '-'
        punctuation = '!\'(),.:;? '
        letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        chars = pad + special + punctuation + letters
    elif symbol_set == 'english_basic_lowercase':
        pad = '_'
        special = '-'
        punctuation = '!\'"(),.:;? '
        letters = 'abcdefghijklmnopqrstuvwxyz'
        chars = pad + special + punctuation + letters
    elif symbol_set == 'english_expanded':
        punctuation = '!\'",.:;? '
        math = '#%&*+-/[]()'
        special = '_@©°½—₩€$'
        accented = 'áçéêëñöøćž'
        letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        chars = punctuation + math + special + accented + letters
    else:
        raise Exception("{} symbol set does not exist".format(symbol_set))
    # ARPAbet symbols ('@'-prefixed) always come after the character symbols.
    return list(chars) + _arpabet
|
||||||
|
|
||||||
|
|
||||||
|
def get_pad_idx(symbol_set='english_basic'):
    """Return the index of the padding symbol for `symbol_set`.

    Both basic sets list the pad character '_' first, so the index is 0.
    Other symbol sets are not supported yet.
    """
    if symbol_set in {'english_basic', 'english_basic_lowercase'}:
        return 0
    raise Exception("{} symbol set not used yet".format(symbol_set))
|
||||||
|
|
175
PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py
Normal file
175
PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
""" adapted from https://github.com/keithito/tacotron """
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
from . import cleaners
|
||||||
|
from .symbols import get_symbols
|
||||||
|
from .cmudict import CMUDict
|
||||||
|
from .numerical import _currency_re, _expand_currency
|
||||||
|
|
||||||
|
|
||||||
|
#########
|
||||||
|
# REGEX #
|
||||||
|
#########
|
||||||
|
|
||||||
|
# Regular expression matching text enclosed in curly braces for encoding
|
||||||
|
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||||
|
|
||||||
|
# Regular expression matching words and not words
|
||||||
|
_words_re = re.compile(r"([a-zA-ZÀ-ž]+['][a-zA-ZÀ-ž]{1,2}|[a-zA-ZÀ-ž]+)|([{][^}]+[}]|[^a-zA-ZÀ-ž{}]+)")
|
||||||
|
|
||||||
|
# Regular expression separating words enclosed in curly braces for cleaning
|
||||||
|
_arpa_re = re.compile(r'{[^}]+}|\S+')
|
||||||
|
|
||||||
|
|
||||||
|
def lines_to_list(filename):
    """Read a UTF-8 text file and return its lines with trailing whitespace stripped."""
    with open(filename, encoding='utf-8') as f:
        return [line.rstrip() for line in f]
|
||||||
|
|
||||||
|
|
||||||
|
class TextProcessing(object):
    """Converts raw text into sequences of symbol IDs.

    Words may be replaced by their ARPAbet transcription (written inside
    curly braces) with probability ``p_arpabet`` — either independently per
    word or for a whole sentence at once, depending on ``handle_arpabet``.
    """

    def __init__(self, symbol_set, cleaner_names, p_arpabet=0.0,
                 handle_arpabet='word', handle_arpabet_ambiguous='ignore',
                 expand_currency=True):
        self.symbols = get_symbols(symbol_set)
        self.cleaner_names = cleaner_names

        # Mappings from symbol to numeric ID and vice versa:
        self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
        self.expand_currency = expand_currency

        # ARPAbet / cmudict settings.
        self.p_arpabet = p_arpabet
        self.handle_arpabet = handle_arpabet
        self.handle_arpabet_ambiguous = handle_arpabet_ambiguous

    def text_to_sequence(self, text):
        """Encode `text` to symbol IDs; {curly-brace} spans are read as ARPAbet."""
        sequence = []
        while len(text):
            braces = _curly_re.match(text)
            if not braces:
                sequence += self.symbols_to_sequence(text)
                break
            sequence += self.symbols_to_sequence(braces.group(1))
            sequence += self.arpabet_to_sequence(braces.group(2))
            text = braces.group(3)
        return sequence

    def sequence_to_text(self, sequence):
        """Decode symbol IDs back to text; ARPAbet symbols are re-wrapped in braces."""
        parts = []
        for symbol_id in sequence:
            # Unknown IDs are silently skipped.
            if symbol_id in self.id_to_symbol:
                s = self.id_to_symbol[symbol_id]
                # ARPAbet symbols are stored with an '@' prefix.
                if len(s) > 1 and s[0] == '@':
                    s = '{%s}' % s[1:]
                parts.append(s)
        # Merge adjacent ARPAbet tokens into one braced group.
        return ''.join(parts).replace('}{', ' ')

    def clean_text(self, text):
        """Apply each configured cleaner (looked up by name in `cleaners`) in order."""
        for name in self.cleaner_names:
            cleaner = getattr(cleaners, name)
            if not cleaner:
                raise Exception('Unknown cleaner: %s' % name)
            text = cleaner(text)
        return text

    def symbols_to_sequence(self, symbols):
        # Symbols missing from the table are silently dropped.
        return [self.symbol_to_id[s] for s in symbols if s in self.symbol_to_id]

    def arpabet_to_sequence(self, text):
        # ARPAbet tokens carry an '@' prefix in the symbol table.
        return self.symbols_to_sequence(['@' + s for s in text.split()])

    def get_arpabet(self, word):
        """Return `word` as "{ARPABET}" when a pronunciation is found, else the word.

        Heteronyms are never transcribed.  Possessive/plural forms fall back
        to the stem plus a trailing /Z/ phoneme when the full form is missing.

        NOTE(review): this method uses a module-level `cmudict` name, but the
        visible imports only bind `CMUDict` — confirm `cmudict` is in scope
        before running with p_arpabet > 0.
        """
        arpabet_suffix = ''

        if word.lower() in cmudict.heteronyms:
            return word

        if len(word) > 2 and word.endswith("'s"):
            arpabet = cmudict.lookup(word)
            if arpabet is None:
                arpabet = self.get_arpabet(word[:-2])
                arpabet_suffix = ' Z'
        elif len(word) > 1 and word.endswith("s"):
            arpabet = cmudict.lookup(word)
            if arpabet is None:
                arpabet = self.get_arpabet(word[:-1])
                arpabet_suffix = ' Z'
        else:
            arpabet = cmudict.lookup(word)

        if arpabet is None:
            return word
        elif arpabet[0] == '{':
            # Result of a recursive call is already braced: unwrap so it is
            # re-wrapped uniformly below.
            arpabet = [arpabet[1:-1]]

        if len(arpabet) > 1:
            # Multiple pronunciations: resolve per configuration.
            if self.handle_arpabet_ambiguous == 'first':
                arpabet = arpabet[0]
            elif self.handle_arpabet_ambiguous == 'random':
                arpabet = np.random.choice(arpabet)
            elif self.handle_arpabet_ambiguous == 'ignore':
                return word
        else:
            arpabet = arpabet[0]

        return "{" + arpabet + arpabet_suffix + "}"

    def encode_text(self, text, return_all=False):
        """Clean `text`, optionally substitute ARPAbet, and encode to symbol IDs.

        With `return_all`, also returns the cleaned text and the
        ARPAbet-substituted text ('' if no substitution happened).
        """
        if self.expand_currency:
            text = re.sub(_currency_re, _expand_currency, text)

        # Clean each whitespace-separated token, leaving {arpabet} spans untouched.
        text_clean = ' '.join(
            self.clean_text(split) if split[0] != '{' else split
            for split in _arpa_re.findall(text))
        text = text_clean

        text_arpabet = ''
        if self.p_arpabet > 0:
            if self.handle_arpabet == 'sentence':
                # All-or-nothing: one coin flip transcribes the whole sentence.
                if np.random.uniform() < self.p_arpabet:
                    chunks = _words_re.findall(text)
                    text_arpabet = ''.join(
                        self.get_arpabet(word) if word != '' else other
                        for word, other in chunks)
                    text = text_arpabet
            elif self.handle_arpabet == 'word':
                # Independent coin flip per word.
                chunks = _words_re.findall(text)
                text_arpabet = ''.join(
                    other if word == '' else (
                        self.get_arpabet(word)
                        if np.random.uniform() < self.p_arpabet
                        else word)
                    for word, other in chunks)
                text = text_arpabet
            elif self.handle_arpabet != '':
                raise Exception("{} handle_arpabet is not supported".format(
                    self.handle_arpabet))

        text_encoded = self.text_to_sequence(text)

        if return_all:
            return text_encoded, text_clean, text_arpabet

        return text_encoded
|
|
@ -25,7 +25,6 @@
|
||||||
#
|
#
|
||||||
# *****************************************************************************
|
# *****************************************************************************
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -48,14 +47,20 @@ def load_wav_to_torch(full_path):
|
||||||
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
|
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
|
||||||
|
|
||||||
|
|
||||||
def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
    """Parse one or more filelists (comma-separated paths in `fnames`).

    Each line contains `split`-separated fields: one or more file paths
    (resolved against `dataset_path`) followed by the text and, when
    `has_speakers` is set, a trailing speaker id.  Returns a list of tuples.
    """
    def parse_line(root, line):
        fields = line.strip().split(split)
        # Trailing non-path fields: text only, or text + speaker id.
        n_tail = 2 if has_speakers else 1
        paths, tail = fields[:-n_tail], fields[-n_tail:]
        return tuple(str(Path(root, p)) for p in paths) + tuple(tail)

    entries = []
    for fname in fnames.split(','):
        with open(fname, encoding='utf-8') as f:
            entries.extend(parse_line(dataset_path, line) for line in f)
    return entries
|
||||||
|
|
||||||
|
|
||||||
def stats_filename(dataset_path, filelist_path, feature_name):
|
def stats_filename(dataset_path, filelist_path, feature_name):
|
||||||
|
|
|
@ -40,6 +40,7 @@ from torch.utils.data import DataLoader
|
||||||
from common import utils
|
from common import utils
|
||||||
from inference import load_and_setup_model
|
from inference import load_and_setup_model
|
||||||
from tacotron2.data_function import TextMelLoader, TextMelCollate, batch_to_gpu
|
from tacotron2.data_function import TextMelLoader, TextMelCollate, batch_to_gpu
|
||||||
|
from common.text.text_processing import TextProcessing
|
||||||
|
|
||||||
|
|
||||||
def parse_args(parser):
|
def parse_args(parser):
|
||||||
|
@ -59,6 +60,8 @@ def parse_args(parser):
|
||||||
parser.add_argument('--text-cleaners', nargs='*',
|
parser.add_argument('--text-cleaners', nargs='*',
|
||||||
default=['english_cleaners'], type=str,
|
default=['english_cleaners'], type=str,
|
||||||
help='Type of text cleaners for input text')
|
help='Type of text cleaners for input text')
|
||||||
|
parser.add_argument('--symbol-set', type=str, default='english_basic',
|
||||||
|
help='Define symbol set for input text')
|
||||||
parser.add_argument('--max-wav-value', default=32768.0, type=float,
|
parser.add_argument('--max-wav-value', default=32768.0, type=float,
|
||||||
help='Maximum audiowave value')
|
help='Maximum audiowave value')
|
||||||
parser.add_argument('--sampling-rate', default=22050, type=int,
|
parser.add_argument('--sampling-rate', default=22050, type=int,
|
||||||
|
@ -98,6 +101,7 @@ def parse_args(parser):
|
||||||
class FilenamedLoader(TextMelLoader):
|
class FilenamedLoader(TextMelLoader):
|
||||||
def __init__(self, filenames, *args, **kwargs):
|
def __init__(self, filenames, *args, **kwargs):
|
||||||
super(FilenamedLoader, self).__init__(*args, **kwargs)
|
super(FilenamedLoader, self).__init__(*args, **kwargs)
|
||||||
|
self.tp = TextProcessing(args[-1].symbol_set, args[-1].text_cleaners)
|
||||||
self.filenames = filenames
|
self.filenames = filenames
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
@ -211,6 +215,8 @@ def main():
|
||||||
|
|
||||||
filenames = [Path(l.split('|')[0]).stem
|
filenames = [Path(l.split('|')[0]).stem
|
||||||
for l in open(args.wav_text_filelist, 'r')]
|
for l in open(args.wav_text_filelist, 'r')]
|
||||||
|
# Compatibility with Tacotron2 Data loader
|
||||||
|
args.n_speakers = 1
|
||||||
dataset = FilenamedLoader(filenames, args.dataset_path, args.wav_text_filelist,
|
dataset = FilenamedLoader(filenames, args.dataset_path, args.wav_text_filelist,
|
||||||
args, load_mel_from_disk=False)
|
args, load_mel_from_disk=False)
|
||||||
# TextMelCollate supports only n_frames_per_step=1
|
# TextMelCollate supports only n_frames_per_step=1
|
||||||
|
|
|
@ -27,8 +27,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from common.text import symbols
|
|
||||||
|
|
||||||
|
|
||||||
def parse_fastpitch_args(parent, add_help=False):
|
def parse_fastpitch_args(parent, add_help=False):
|
||||||
"""
|
"""
|
||||||
|
@ -42,11 +40,12 @@ def parse_fastpitch_args(parent, add_help=False):
|
||||||
help='Number of bins in mel-spectrograms')
|
help='Number of bins in mel-spectrograms')
|
||||||
io.add_argument('--max-seq-len', default=2048, type=int,
|
io.add_argument('--max-seq-len', default=2048, type=int,
|
||||||
help='')
|
help='')
|
||||||
global symbols
|
|
||||||
len_symbols = len(symbols)
|
|
||||||
symbols = parser.add_argument_group('symbols parameters')
|
symbols = parser.add_argument_group('symbols parameters')
|
||||||
symbols.add_argument('--n-symbols', default=len_symbols, type=int,
|
symbols.add_argument('--n-symbols', default=148, type=int,
|
||||||
help='Number of symbols in dictionary')
|
help='Number of symbols in dictionary')
|
||||||
|
symbols.add_argument('--padding-idx', default=0, type=int,
|
||||||
|
help='Index of padding symbol in dictionary')
|
||||||
symbols.add_argument('--symbols-embedding-dim', default=384, type=int,
|
symbols.add_argument('--symbols-embedding-dim', default=384, type=int,
|
||||||
help='Input embedding dimension')
|
help='Input embedding dimension')
|
||||||
|
|
||||||
|
@ -102,11 +101,18 @@ def parse_fastpitch_args(parent, add_help=False):
|
||||||
|
|
||||||
pitch_pred = parser.add_argument_group('pitch predictor parameters')
|
pitch_pred = parser.add_argument_group('pitch predictor parameters')
|
||||||
pitch_pred.add_argument('--pitch-predictor-kernel-size', default=3, type=int,
|
pitch_pred.add_argument('--pitch-predictor-kernel-size', default=3, type=int,
|
||||||
help='Pitch predictor conv-1D kernel size')
|
help='Pitch predictor conv-1D kernel size')
|
||||||
pitch_pred.add_argument('--pitch-predictor-filter-size', default=256, type=int,
|
pitch_pred.add_argument('--pitch-predictor-filter-size', default=256, type=int,
|
||||||
help='Pitch predictor conv-1D filter size')
|
help='Pitch predictor conv-1D filter size')
|
||||||
pitch_pred.add_argument('--p-pitch-predictor-dropout', default=0.1, type=float,
|
pitch_pred.add_argument('--p-pitch-predictor-dropout', default=0.1, type=float,
|
||||||
help='Pitch probability for pitch predictor')
|
help='Pitch probability for pitch predictor')
|
||||||
pitch_pred.add_argument('--pitch-predictor-n-layers', default=2, type=int,
|
pitch_pred.add_argument('--pitch-predictor-n-layers', default=2, type=int,
|
||||||
help='Number of conv-1D layers')
|
help='Number of conv-1D layers')
|
||||||
|
|
||||||
|
cond = parser.add_argument_group('conditioning parameters')
|
||||||
|
cond.add_argument('--pitch-embedding-kernel-size', default=3, type=int,
|
||||||
|
help='Pitch embedding conv-1D kernel size')
|
||||||
|
cond.add_argument('--speaker-emb-weight', type=float, default=1.0,
|
||||||
|
help='Scale speaker embedding')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
|
@ -31,6 +31,7 @@ import torch
|
||||||
|
|
||||||
from common.utils import to_gpu
|
from common.utils import to_gpu
|
||||||
from tacotron2.data_function import TextMelLoader
|
from tacotron2.data_function import TextMelLoader
|
||||||
|
from common.text.text_processing import TextProcessing
|
||||||
|
|
||||||
|
|
||||||
class TextMelAliLoader(TextMelLoader):
|
class TextMelAliLoader(TextMelLoader):
|
||||||
|
@ -38,18 +39,27 @@ class TextMelAliLoader(TextMelLoader):
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args):
    """Build on TextMelLoader; the last positional arg is the hparams namespace."""
    super(TextMelAliLoader, self).__init__(*args)
    hparams = args[-1]
    self.tp = TextProcessing(hparams.symbol_set, hparams.text_cleaners)
    self.n_speakers = hparams.n_speakers
    # Multi-speaker filelists carry one extra (speaker id) column.
    expected_columns = 4 + (hparams.n_speakers > 1)
    if len(self.audiopaths_and_text[0]) != expected_columns:
        raise ValueError('Expected four columns in audiopaths file for single speaker model. \n'
                         'For multispeaker model, the filelist format is '
                         '<mel>|<dur>|<pitch>|<text>|<speaker_id>')
|
||||||
|
|
||||||
def __getitem__(self, index):
    """Return (text_ids, mel, text_len, durations, pitch, speaker-or-None)."""
    # separate filename and text (+ optional speaker id)
    entry = self.audiopaths_and_text[index]
    if self.n_speakers > 1:
        audiopath, durpath, pitchpath, text, speaker = entry
        speaker = int(speaker)
    else:
        audiopath, durpath, pitchpath, text = entry
        speaker = None
    len_text = len(text)
    text = self.get_text(text)
    mel = self.get_mel(audiopath)
    dur = torch.load(durpath)
    pitch = torch.load(pitchpath)
    return (text, mel, len_text, dur, pitch, speaker)
|
||||||
|
|
||||||
|
|
||||||
class TextMelAliCollate():
|
class TextMelAliCollate():
|
||||||
|
@ -107,16 +117,24 @@ class TextMelAliCollate():
|
||||||
pitch = batch[ids_sorted_decreasing[i]][4]
|
pitch = batch[ids_sorted_decreasing[i]][4]
|
||||||
pitch_padded[i, :pitch.shape[0]] = pitch
|
pitch_padded[i, :pitch.shape[0]] = pitch
|
||||||
|
|
||||||
|
if batch[0][5] is not None:
|
||||||
|
speaker = torch.zeros_like(input_lengths)
|
||||||
|
for i in range(len(ids_sorted_decreasing)):
|
||||||
|
speaker[i] = batch[ids_sorted_decreasing[i]][5]
|
||||||
|
else:
|
||||||
|
speaker = None
|
||||||
|
|
||||||
# count number of items - characters in text
|
# count number of items - characters in text
|
||||||
len_x = [x[2] for x in batch]
|
len_x = [x[2] for x in batch]
|
||||||
len_x = torch.Tensor(len_x)
|
len_x = torch.Tensor(len_x)
|
||||||
return (text_padded, input_lengths, mel_padded,
|
|
||||||
output_lengths, len_x, dur_padded, dur_lens, pitch_padded)
|
return (text_padded, input_lengths, mel_padded, output_lengths,
|
||||||
|
len_x, dur_padded, dur_lens, pitch_padded, speaker)
|
||||||
|
|
||||||
|
|
||||||
def batch_to_gpu(batch):
|
def batch_to_gpu(batch):
|
||||||
text_padded, input_lengths, mel_padded, \
|
text_padded, input_lengths, mel_padded, output_lengths, \
|
||||||
output_lengths, len_x, dur_padded, dur_lens, pitch_padded = batch
|
len_x, dur_padded, dur_lens, pitch_padded, speaker = batch
|
||||||
text_padded = to_gpu(text_padded).long()
|
text_padded = to_gpu(text_padded).long()
|
||||||
input_lengths = to_gpu(input_lengths).long()
|
input_lengths = to_gpu(input_lengths).long()
|
||||||
mel_padded = to_gpu(mel_padded).float()
|
mel_padded = to_gpu(mel_padded).float()
|
||||||
|
@ -124,9 +142,11 @@ def batch_to_gpu(batch):
|
||||||
dur_padded = to_gpu(dur_padded).long()
|
dur_padded = to_gpu(dur_padded).long()
|
||||||
dur_lens = to_gpu(dur_lens).long()
|
dur_lens = to_gpu(dur_lens).long()
|
||||||
pitch_padded = to_gpu(pitch_padded).float()
|
pitch_padded = to_gpu(pitch_padded).float()
|
||||||
|
if speaker is not None:
|
||||||
|
speaker = to_gpu(speaker).long()
|
||||||
# Alignments act as both inputs and targets - pass shallow copies
|
# Alignments act as both inputs and targets - pass shallow copies
|
||||||
x = [text_padded, input_lengths, mel_padded, output_lengths,
|
x = [text_padded, input_lengths, mel_padded, output_lengths,
|
||||||
dur_padded, dur_lens, pitch_padded]
|
dur_padded, dur_lens, pitch_padded, speaker]
|
||||||
y = [mel_padded, dur_padded, dur_lens, pitch_padded]
|
y = [mel_padded, dur_padded, dur_lens, pitch_padded]
|
||||||
len_x = torch.sum(output_lengths)
|
len_x = torch.sum(output_lengths)
|
||||||
return (x, y, len_x)
|
return (x, y, len_x)
|
||||||
|
|
|
@ -73,7 +73,7 @@ class FastPitchLoss(nn.Module):
|
||||||
'mel_loss': mel_loss.clone().detach(),
|
'mel_loss': mel_loss.clone().detach(),
|
||||||
'duration_predictor_loss': dur_pred_loss.clone().detach(),
|
'duration_predictor_loss': dur_pred_loss.clone().detach(),
|
||||||
'pitch_loss': pitch_loss.clone().detach(),
|
'pitch_loss': pitch_loss.clone().detach(),
|
||||||
'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
|
'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
|
||||||
/ dur_mask.sum()).detach(),
|
/ dur_mask.sum()).detach(),
|
||||||
}
|
}
|
||||||
assert meta_agg in ('sum', 'mean')
|
assert meta_agg in ('sum', 'mean')
|
||||||
|
|
|
@ -1,210 +1,236 @@
|
||||||
# *****************************************************************************
|
# *****************************************************************************
|
||||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
#
|
#
|
||||||
# Redistribution and use in source and binary forms, with or without
|
# Redistribution and use in source and binary forms, with or without
|
||||||
# modification, are permitted provided that the following conditions are met:
|
# modification, are permitted provided that the following conditions are met:
|
||||||
# * Redistributions of source code must retain the above copyright
|
# * Redistributions of source code must retain the above copyright
|
||||||
# notice, this list of conditions and the following disclaimer.
|
# notice, this list of conditions and the following disclaimer.
|
||||||
# * Redistributions in binary form must reproduce the above copyright
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
# notice, this list of conditions and the following disclaimer in the
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
# documentation and/or other materials provided with the distribution.
|
# documentation and/or other materials provided with the distribution.
|
||||||
# * Neither the name of the NVIDIA CORPORATION nor the
|
# * Neither the name of the NVIDIA CORPORATION nor the
|
||||||
# names of its contributors may be used to endorse or promote products
|
# names of its contributors may be used to endorse or promote products
|
||||||
# derived from this software without specific prior written permission.
|
# derived from this software without specific prior written permission.
|
||||||
#
|
#
|
||||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||||
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
#
|
#
|
||||||
# *****************************************************************************
|
# *****************************************************************************
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from common.layers import ConvReLUNorm
|
from common.layers import ConvReLUNorm
|
||||||
from common.utils import mask_from_lens
|
from common.utils import mask_from_lens
|
||||||
from fastpitch.transformer import FFTransformer
|
from fastpitch.transformer import FFTransformer
|
||||||
|
|
||||||
|
|
||||||
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """Expand encoder frames according to per-token durations.

    Each encoder output vector is repeated round(duration / pace) times.
    Returns (expanded_encodings, decoder_lengths); when `mel_max_len` is
    given, frames and lengths are clamped to that many steps.
    """
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)

    expanded = [torch.repeat_interleave(enc, rep, dim=0)
                for enc, rep in zip(enc_out, reps)]
    enc_rep = pad_sequence(expanded, batch_first=True)

    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens
|
||||||
|
|
||||||
|
|
||||||
class TemporalPredictor(nn.Module):
    """Predicts a single float per each temporal location"""

    def __init__(self, input_size, filter_size, kernel_size, dropout,
                 n_layers=2):
        super(TemporalPredictor, self).__init__()

        # Stack of 1-D conv blocks; the first maps input_size -> filter_size.
        conv_stack = []
        for i in range(n_layers):
            in_channels = input_size if i == 0 else filter_size
            conv_stack.append(
                ConvReLUNorm(in_channels, filter_size,
                             kernel_size=kernel_size, dropout=dropout))
        self.layers = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        """enc_out: (B, T, C) encoder states; enc_out_mask zeroes padded steps."""
        # Mask before and after the conv stack so padding never leaks through.
        hidden = enc_out * enc_out_mask
        hidden = self.layers(hidden.transpose(1, 2)).transpose(1, 2)
        out = self.fc(hidden) * enc_out_mask
        return out.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
class FastPitch(nn.Module):
|
class FastPitch(nn.Module):
|
||||||
def __init__(self, n_mel_channels, max_seq_len, n_symbols,
|
def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
|
||||||
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
||||||
in_fft_d_head,
|
in_fft_d_head,
|
||||||
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
||||||
in_fft_output_size,
|
in_fft_output_size,
|
||||||
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
||||||
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
||||||
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
||||||
out_fft_output_size,
|
out_fft_output_size,
|
||||||
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
||||||
dur_predictor_kernel_size, dur_predictor_filter_size,
|
dur_predictor_kernel_size, dur_predictor_filter_size,
|
||||||
p_dur_predictor_dropout, dur_predictor_n_layers,
|
p_dur_predictor_dropout, dur_predictor_n_layers,
|
||||||
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
||||||
p_pitch_predictor_dropout, pitch_predictor_n_layers):
|
p_pitch_predictor_dropout, pitch_predictor_n_layers,
|
||||||
super(FastPitch, self).__init__()
|
pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
|
||||||
del max_seq_len # unused
|
super(FastPitch, self).__init__()
|
||||||
del n_symbols
|
del max_seq_len # unused
|
||||||
|
|
||||||
self.encoder = FFTransformer(
|
self.encoder = FFTransformer(
|
||||||
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
||||||
d_model=symbols_embedding_dim,
|
d_model=symbols_embedding_dim,
|
||||||
d_head=in_fft_d_head,
|
d_head=in_fft_d_head,
|
||||||
d_inner=in_fft_conv1d_filter_size,
|
d_inner=in_fft_conv1d_filter_size,
|
||||||
kernel_size=in_fft_conv1d_kernel_size,
|
kernel_size=in_fft_conv1d_kernel_size,
|
||||||
dropout=p_in_fft_dropout,
|
dropout=p_in_fft_dropout,
|
||||||
dropatt=p_in_fft_dropatt,
|
dropatt=p_in_fft_dropatt,
|
||||||
dropemb=p_in_fft_dropemb,
|
dropemb=p_in_fft_dropemb,
|
||||||
d_embed=symbols_embedding_dim,
|
embed_input=True,
|
||||||
embed_input=True)
|
d_embed=symbols_embedding_dim,
|
||||||
|
n_embed=n_symbols,
|
||||||
self.duration_predictor = TemporalPredictor(
|
padding_idx=padding_idx)
|
||||||
in_fft_output_size,
|
|
||||||
filter_size=dur_predictor_filter_size,
|
if n_speakers > 1:
|
||||||
kernel_size=dur_predictor_kernel_size,
|
self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
|
||||||
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
else:
|
||||||
)
|
self.speaker_emb = None
|
||||||
|
self.speaker_emb_weight = speaker_emb_weight
|
||||||
self.decoder = FFTransformer(
|
|
||||||
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
self.duration_predictor = TemporalPredictor(
|
||||||
d_model=symbols_embedding_dim,
|
in_fft_output_size,
|
||||||
d_head=out_fft_d_head,
|
filter_size=dur_predictor_filter_size,
|
||||||
d_inner=out_fft_conv1d_filter_size,
|
kernel_size=dur_predictor_kernel_size,
|
||||||
kernel_size=out_fft_conv1d_kernel_size,
|
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
||||||
dropout=p_out_fft_dropout,
|
)
|
||||||
dropatt=p_out_fft_dropatt,
|
|
||||||
dropemb=p_out_fft_dropemb,
|
self.decoder = FFTransformer(
|
||||||
d_embed=symbols_embedding_dim,
|
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
||||||
embed_input=False)
|
d_model=symbols_embedding_dim,
|
||||||
|
d_head=out_fft_d_head,
|
||||||
self.pitch_predictor = TemporalPredictor(
|
d_inner=out_fft_conv1d_filter_size,
|
||||||
in_fft_output_size,
|
kernel_size=out_fft_conv1d_kernel_size,
|
||||||
filter_size=pitch_predictor_filter_size,
|
dropout=p_out_fft_dropout,
|
||||||
kernel_size=pitch_predictor_kernel_size,
|
dropatt=p_out_fft_dropatt,
|
||||||
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
dropemb=p_out_fft_dropemb,
|
||||||
)
|
embed_input=False,
|
||||||
self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
|
d_embed=symbols_embedding_dim
|
||||||
padding=1)
|
)
|
||||||
|
|
||||||
# Store values precomputed for training data within the model
|
self.pitch_predictor = TemporalPredictor(
|
||||||
self.register_buffer('pitch_mean', torch.zeros(1))
|
in_fft_output_size,
|
||||||
self.register_buffer('pitch_std', torch.zeros(1))
|
filter_size=pitch_predictor_filter_size,
|
||||||
|
kernel_size=pitch_predictor_kernel_size,
|
||||||
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
||||||
|
)
|
||||||
def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
|
|
||||||
pace=1.0, max_duration=75):
|
self.pitch_emb = nn.Conv1d(
|
||||||
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
|
1, symbols_embedding_dim,
|
||||||
mel_max_len = mel_tgt.size(2)
|
kernel_size=pitch_embedding_kernel_size,
|
||||||
|
padding=int((pitch_embedding_kernel_size - 1) / 2))
|
||||||
# Input FFT
|
|
||||||
enc_out, enc_mask = self.encoder(inputs)
|
# Store values precomputed for training data within the model
|
||||||
|
self.register_buffer('pitch_mean', torch.zeros(1))
|
||||||
# Embedded for predictors
|
self.register_buffer('pitch_std', torch.zeros(1))
|
||||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
|
||||||
|
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
||||||
# Predict durations
|
|
||||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
|
||||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
pace=1.0, max_duration=75):
|
||||||
|
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
|
||||||
# Predict pitch
|
mel_max_len = mel_tgt.size(2)
|
||||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
|
||||||
|
# Calculate speaker embedding
|
||||||
if use_gt_pitch and pitch_tgt is not None:
|
if self.speaker_emb is None:
|
||||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
spk_emb = 0
|
||||||
else:
|
else:
|
||||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||||
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
spk_emb.mul_(self.speaker_emb_weight)
|
||||||
|
|
||||||
len_regulated, dec_lens = regulate_len(
|
# Input FFT
|
||||||
dur_tgt if use_gt_durations else dur_pred,
|
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||||
enc_out, pace, mel_max_len)
|
|
||||||
|
# Embedded for predictors
|
||||||
# Output FFT
|
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
|
||||||
mel_out = self.proj(dec_out)
|
# Predict durations
|
||||||
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||||
|
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||||
def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
|
|
||||||
pitch_transform=None, max_duration=75):
|
# Predict pitch
|
||||||
del input_lens # unused
|
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||||
|
|
||||||
# Input FFT
|
if use_gt_pitch and pitch_tgt is not None:
|
||||||
enc_out, enc_mask = self.encoder(inputs)
|
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
||||||
|
else:
|
||||||
# Embedded for predictors
|
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
||||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
||||||
|
|
||||||
# Predict durations
|
len_regulated, dec_lens = regulate_len(
|
||||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
dur_tgt if use_gt_durations else dur_pred,
|
||||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
enc_out, pace, mel_max_len)
|
||||||
|
|
||||||
# Pitch over chars
|
# Output FFT
|
||||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||||
|
mel_out = self.proj(dec_out)
|
||||||
if pitch_transform is not None:
|
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
||||||
if self.pitch_std[0] == 0.0:
|
|
||||||
# XXX LJSpeech-1.1 defaults
|
def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
|
||||||
mean, std = 218.14, 67.24
|
pitch_transform=None, max_duration=75, speaker=0):
|
||||||
else:
|
del input_lens # unused
|
||||||
mean, std = self.pitch_mean[0], self.pitch_std[0]
|
|
||||||
pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
|
if self.speaker_emb is None:
|
||||||
|
spk_emb = 0
|
||||||
if pitch_tgt is None:
|
else:
|
||||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
|
||||||
else:
|
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
spk_emb.mul_(self.speaker_emb_weight)
|
||||||
|
|
||||||
enc_out = enc_out + pitch_emb
|
# Input FFT
|
||||||
|
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||||
len_regulated, dec_lens = regulate_len(
|
|
||||||
dur_pred if dur_tgt is None else dur_tgt,
|
# Embedded for predictors
|
||||||
enc_out, pace, mel_max_len=None)
|
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||||
|
|
||||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
# Predict durations
|
||||||
mel_out = self.proj(dec_out)
|
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||||
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||||
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
|
||||||
return mel_out, dec_lens, dur_pred, pitch_pred
|
# Pitch over chars
|
||||||
|
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||||
|
|
||||||
|
if pitch_transform is not None:
|
||||||
|
if self.pitch_std[0] == 0.0:
|
||||||
|
# XXX LJSpeech-1.1 defaults
|
||||||
|
mean, std = 218.14, 67.24
|
||||||
|
else:
|
||||||
|
mean, std = self.pitch_mean[0], self.pitch_std[0]
|
||||||
|
pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
|
||||||
|
|
||||||
|
if pitch_tgt is None:
|
||||||
|
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
||||||
|
else:
|
||||||
|
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
||||||
|
|
||||||
|
enc_out = enc_out + pitch_emb
|
||||||
|
|
||||||
|
len_regulated, dec_lens = regulate_len(
|
||||||
|
dur_pred if dur_tgt is None else dur_tgt,
|
||||||
|
enc_out, pace, mel_max_len=None)
|
||||||
|
|
||||||
|
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||||
|
mel_out = self.proj(dec_out)
|
||||||
|
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
||||||
|
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
||||||
|
return mel_out, dec_lens, dur_pred, pitch_pred
|
||||||
|
|
|
@ -1,218 +1,246 @@
|
||||||
# *****************************************************************************
|
# *****************************************************************************
|
||||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
#
|
#
|
||||||
# Redistribution and use in source and binary forms, with or without
|
# Redistribution and use in source and binary forms, with or without
|
||||||
# modification, are permitted provided that the following conditions are met:
|
# modification, are permitted provided that the following conditions are met:
|
||||||
# * Redistributions of source code must retain the above copyright
|
# * Redistributions of source code must retain the above copyright
|
||||||
# notice, this list of conditions and the following disclaimer.
|
# notice, this list of conditions and the following disclaimer.
|
||||||
# * Redistributions in binary form must reproduce the above copyright
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
# notice, this list of conditions and the following disclaimer in the
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
# documentation and/or other materials provided with the distribution.
|
# documentation and/or other materials provided with the distribution.
|
||||||
# * Neither the name of the NVIDIA CORPORATION nor the
|
# * Neither the name of the NVIDIA CORPORATION nor the
|
||||||
# names of its contributors may be used to endorse or promote products
|
# names of its contributors may be used to endorse or promote products
|
||||||
# derived from this software without specific prior written permission.
|
# derived from this software without specific prior written permission.
|
||||||
#
|
#
|
||||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||||
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
#
|
#
|
||||||
# *****************************************************************************
|
# *****************************************************************************
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from common.layers import ConvReLUNorm
|
from common.layers import ConvReLUNorm
|
||||||
from fastpitch.transformer_jit import FFTransformer
|
from fastpitch.transformer_jit import FFTransformer
|
||||||
|
|
||||||
|
|
||||||
def regulate_len(durations, enc_out, pace: float = 1.0,
|
def regulate_len(durations, enc_out, pace: float = 1.0,
|
||||||
mel_max_len: Optional[int] = None):
|
mel_max_len: Optional[int] = None):
|
||||||
"""If target=None, then predicted durations are applied"""
|
"""If target=None, then predicted durations are applied"""
|
||||||
reps = torch.round(durations.float() / pace).long()
|
reps = torch.round(durations.float() / pace).long()
|
||||||
dec_lens = reps.sum(dim=1)
|
dec_lens = reps.sum(dim=1)
|
||||||
|
|
||||||
max_len = dec_lens.max()
|
max_len = dec_lens.max()
|
||||||
bsz, _, hid = enc_out.size()
|
bsz, _, hid = enc_out.size()
|
||||||
|
|
||||||
reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
|
reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
|
||||||
pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
|
pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
|
||||||
device=enc_out.device)
|
device=enc_out.device)
|
||||||
|
|
||||||
enc_rep = torch.cat([enc_out, pad_vec], dim=1)
|
enc_rep = torch.cat([enc_out, pad_vec], dim=1)
|
||||||
enc_rep = torch.repeat_interleave(
|
enc_rep = torch.repeat_interleave(
|
||||||
enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
|
enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
|
||||||
).view(bsz, -1, hid)
|
).view(bsz, -1, hid)
|
||||||
|
|
||||||
# enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
# enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
||||||
# for o, r in zip(enc_out, reps)],
|
# for o, r in zip(enc_out, reps)],
|
||||||
# batch_first=True)
|
# batch_first=True)
|
||||||
if mel_max_len is not None:
|
if mel_max_len is not None:
|
||||||
enc_rep = enc_rep[:, :mel_max_len]
|
enc_rep = enc_rep[:, :mel_max_len]
|
||||||
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
||||||
return enc_rep, dec_lens
|
return enc_rep, dec_lens
|
||||||
|
|
||||||
|
|
||||||
class TemporalPredictor(nn.Module):
|
class TemporalPredictor(nn.Module):
|
||||||
"""Predicts a single float per each temporal location"""
|
"""Predicts a single float per each temporal location"""
|
||||||
|
|
||||||
def __init__(self, input_size, filter_size, kernel_size, dropout,
|
def __init__(self, input_size, filter_size, kernel_size, dropout,
|
||||||
n_layers=2):
|
n_layers=2):
|
||||||
super(TemporalPredictor, self).__init__()
|
super(TemporalPredictor, self).__init__()
|
||||||
|
|
||||||
self.layers = nn.Sequential(*[
|
self.layers = nn.Sequential(*[
|
||||||
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
||||||
kernel_size=kernel_size, dropout=dropout)
|
kernel_size=kernel_size, dropout=dropout)
|
||||||
for i in range(n_layers)]
|
for i in range(n_layers)]
|
||||||
)
|
)
|
||||||
self.fc = nn.Linear(filter_size, 1, bias=True)
|
self.fc = nn.Linear(filter_size, 1, bias=True)
|
||||||
|
|
||||||
def forward(self, enc_out, enc_out_mask):
|
def forward(self, enc_out, enc_out_mask):
|
||||||
out = enc_out * enc_out_mask
|
out = enc_out * enc_out_mask
|
||||||
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
||||||
out = self.fc(out) * enc_out_mask
|
out = self.fc(out) * enc_out_mask
|
||||||
return out.squeeze(-1)
|
return out.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
class FastPitch(nn.Module):
|
class FastPitch(nn.Module):
|
||||||
def __init__(self, n_mel_channels, max_seq_len, n_symbols,
|
def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
|
||||||
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
||||||
in_fft_d_head,
|
in_fft_d_head,
|
||||||
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
||||||
in_fft_output_size,
|
in_fft_output_size,
|
||||||
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
||||||
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
||||||
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
||||||
out_fft_output_size,
|
out_fft_output_size,
|
||||||
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
||||||
dur_predictor_kernel_size, dur_predictor_filter_size,
|
dur_predictor_kernel_size, dur_predictor_filter_size,
|
||||||
p_dur_predictor_dropout, dur_predictor_n_layers,
|
p_dur_predictor_dropout, dur_predictor_n_layers,
|
||||||
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
||||||
p_pitch_predictor_dropout, pitch_predictor_n_layers):
|
p_pitch_predictor_dropout, pitch_predictor_n_layers,
|
||||||
super(FastPitch, self).__init__()
|
pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
|
||||||
del max_seq_len # unused
|
super(FastPitch, self).__init__()
|
||||||
del n_symbols
|
del max_seq_len # unused
|
||||||
|
|
||||||
self.encoder = FFTransformer(
|
self.encoder = FFTransformer(
|
||||||
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
||||||
d_model=symbols_embedding_dim,
|
d_model=symbols_embedding_dim,
|
||||||
d_head=in_fft_d_head,
|
d_head=in_fft_d_head,
|
||||||
d_inner=in_fft_conv1d_filter_size,
|
d_inner=in_fft_conv1d_filter_size,
|
||||||
kernel_size=in_fft_conv1d_kernel_size,
|
kernel_size=in_fft_conv1d_kernel_size,
|
||||||
dropout=p_in_fft_dropout,
|
dropout=p_in_fft_dropout,
|
||||||
dropatt=p_in_fft_dropatt,
|
dropatt=p_in_fft_dropatt,
|
||||||
dropemb=p_in_fft_dropemb,
|
dropemb=p_in_fft_dropemb,
|
||||||
d_embed=symbols_embedding_dim,
|
embed_input=True,
|
||||||
embed_input=True)
|
d_embed=symbols_embedding_dim,
|
||||||
|
n_embed=n_symbols,
|
||||||
self.duration_predictor = TemporalPredictor(
|
padding_idx=padding_idx)
|
||||||
in_fft_output_size,
|
|
||||||
filter_size=dur_predictor_filter_size,
|
if n_speakers > 1:
|
||||||
kernel_size=dur_predictor_kernel_size,
|
self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
|
||||||
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
else:
|
||||||
)
|
self.speaker_emb = None
|
||||||
|
self.speaker_emb_weight = speaker_emb_weight
|
||||||
self.decoder = FFTransformer(
|
|
||||||
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
self.duration_predictor = TemporalPredictor(
|
||||||
d_model=symbols_embedding_dim,
|
in_fft_output_size,
|
||||||
d_head=out_fft_d_head,
|
filter_size=dur_predictor_filter_size,
|
||||||
d_inner=out_fft_conv1d_filter_size,
|
kernel_size=dur_predictor_kernel_size,
|
||||||
kernel_size=out_fft_conv1d_kernel_size,
|
dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
||||||
dropout=p_out_fft_dropout,
|
)
|
||||||
dropatt=p_out_fft_dropatt,
|
|
||||||
dropemb=p_out_fft_dropemb,
|
self.decoder = FFTransformer(
|
||||||
d_embed=symbols_embedding_dim,
|
n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
||||||
embed_input=False)
|
d_model=symbols_embedding_dim,
|
||||||
|
d_head=out_fft_d_head,
|
||||||
self.pitch_predictor = TemporalPredictor(
|
d_inner=out_fft_conv1d_filter_size,
|
||||||
in_fft_output_size,
|
kernel_size=out_fft_conv1d_kernel_size,
|
||||||
filter_size=pitch_predictor_filter_size,
|
dropout=p_out_fft_dropout,
|
||||||
kernel_size=pitch_predictor_kernel_size,
|
dropatt=p_out_fft_dropatt,
|
||||||
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
dropemb=p_out_fft_dropemb,
|
||||||
)
|
embed_input=False,
|
||||||
self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
|
d_embed=symbols_embedding_dim
|
||||||
padding=1)
|
)
|
||||||
|
|
||||||
# Store values precomputed for training data within the model
|
self.pitch_predictor = TemporalPredictor(
|
||||||
self.register_buffer('pitch_mean', torch.zeros(1))
|
in_fft_output_size,
|
||||||
self.register_buffer('pitch_std', torch.zeros(1))
|
filter_size=pitch_predictor_filter_size,
|
||||||
|
kernel_size=pitch_predictor_kernel_size,
|
||||||
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
||||||
|
)
|
||||||
def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
|
|
||||||
use_gt_pitch: bool = True, pace: float = 1.0,
|
self.pitch_emb = nn.Conv1d(
|
||||||
max_duration: int = 75):
|
1, symbols_embedding_dim,
|
||||||
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
|
kernel_size=pitch_embedding_kernel_size,
|
||||||
mel_max_len = mel_tgt.size(2)
|
padding=int((pitch_embedding_kernel_size - 1) / 2))
|
||||||
|
|
||||||
# Input FFT
|
# Store values precomputed for training data within the model
|
||||||
enc_out, enc_mask = self.encoder(inputs)
|
self.register_buffer('pitch_mean', torch.zeros(1))
|
||||||
|
self.register_buffer('pitch_std', torch.zeros(1))
|
||||||
# Embedded for predictors
|
|
||||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
||||||
|
|
||||||
# Predict durations
|
def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
|
||||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
use_gt_pitch: bool = True, pace: float = 1.0,
|
||||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
max_duration: int = 75):
|
||||||
|
inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
|
||||||
# Predict pitch
|
mel_max_len = mel_tgt.size(2)
|
||||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
|
||||||
|
# Calculate speaker embedding
|
||||||
if use_gt_pitch and pitch_tgt is not None:
|
if self.speaker_emb is None:
|
||||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
spk_emb = 0
|
||||||
else:
|
else:
|
||||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||||
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
spk_emb.mul_(self.speaker_emb_weight)
|
||||||
|
|
||||||
len_regulated, dec_lens = regulate_len(
|
# Input FFT
|
||||||
dur_tgt if use_gt_durations else dur_pred,
|
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||||
enc_out, pace, mel_max_len)
|
|
||||||
|
# Embedded for predictors
|
||||||
# Output FFT
|
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
|
||||||
mel_out = self.proj(dec_out)
|
# Predict durations
|
||||||
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||||
|
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||||
def infer(self, inputs, input_lens, pace: float = 1.0,
|
|
||||||
dur_tgt: Optional[torch.Tensor] = None,
|
# Predict pitch
|
||||||
pitch_tgt: Optional[torch.Tensor] = None,
|
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||||
max_duration: float = 75):
|
|
||||||
|
if use_gt_pitch and pitch_tgt is not None:
|
||||||
# Input FFT
|
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
||||||
enc_out, enc_mask = self.encoder(inputs)
|
else:
|
||||||
|
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
||||||
# Embedded for predictors
|
enc_out = enc_out + pitch_emb.transpose(1, 2)
|
||||||
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
|
||||||
|
len_regulated, dec_lens = regulate_len(
|
||||||
# Predict durations
|
dur_tgt if use_gt_durations else dur_pred,
|
||||||
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
enc_out, pace, mel_max_len)
|
||||||
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
|
||||||
|
# Output FFT
|
||||||
# Pitch over chars
|
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||||
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
mel_out = self.proj(dec_out)
|
||||||
|
return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
||||||
if pitch_tgt is None:
|
|
||||||
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
def infer(self, inputs, input_lens, pace: float = 1.0,
|
||||||
else:
|
dur_tgt: Optional[torch.Tensor] = None,
|
||||||
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
pitch_tgt: Optional[torch.Tensor] = None,
|
||||||
|
max_duration: float = 75,
|
||||||
enc_out = enc_out + pitch_emb
|
speaker: int = 0):
|
||||||
|
del input_lens # unused
|
||||||
len_regulated, dec_lens = regulate_len(
|
|
||||||
dur_pred if dur_tgt is None else dur_tgt,
|
if self.speaker_emb is None:
|
||||||
enc_out, pace, mel_max_len=None)
|
spk_emb = None
|
||||||
|
else:
|
||||||
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
speaker = torch.ones(inputs.size(0), dtype=torch.long, device=inputs.device).fill_(speaker)
|
||||||
mel_out = self.proj(dec_out)
|
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
|
||||||
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
spk_emb.mul_(self.speaker_emb_weight)
|
||||||
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
|
||||||
return mel_out, dec_lens, dur_pred, pitch_pred
|
# Input FFT
|
||||||
|
enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
|
||||||
|
|
||||||
|
# Embedded for predictors
|
||||||
|
pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
||||||
|
|
||||||
|
# Predict durations
|
||||||
|
log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
||||||
|
dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
||||||
|
|
||||||
|
# Pitch over chars
|
||||||
|
pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
||||||
|
|
||||||
|
if pitch_tgt is None:
|
||||||
|
pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
||||||
|
else:
|
||||||
|
pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
||||||
|
|
||||||
|
enc_out = enc_out + pitch_emb
|
||||||
|
|
||||||
|
len_regulated, dec_lens = regulate_len(
|
||||||
|
dur_pred if dur_tgt is None else dur_tgt,
|
||||||
|
enc_out, pace, mel_max_len=None)
|
||||||
|
|
||||||
|
dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
||||||
|
mel_out = self.proj(dec_out)
|
||||||
|
# mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
||||||
|
mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
||||||
|
return mel_out, dec_lens, dur_pred, pitch_pred
|
||||||
|
|
|
@ -17,7 +17,6 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from common.utils import mask_from_lens
|
from common.utils import mask_from_lens
|
||||||
from common.text.symbols import pad_idx, symbols
|
|
||||||
|
|
||||||
|
|
||||||
class PositionalEmbedding(nn.Module):
|
class PositionalEmbedding(nn.Module):
|
||||||
|
@ -248,16 +247,17 @@ class TransformerLayer(nn.Module):
|
||||||
|
|
||||||
class FFTransformer(nn.Module):
|
class FFTransformer(nn.Module):
|
||||||
def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
|
def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
|
||||||
dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
|
dropout, dropatt, dropemb=0.0, embed_input=True,
|
||||||
pre_lnorm=False):
|
n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
|
||||||
super(FFTransformer, self).__init__()
|
super(FFTransformer, self).__init__()
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
|
||||||
if embed_input:
|
if embed_input:
|
||||||
self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
|
self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
|
||||||
padding_idx=pad_idx)
|
padding_idx=self.padding_idx)
|
||||||
else:
|
else:
|
||||||
self.word_emb = None
|
self.word_emb = None
|
||||||
|
|
||||||
|
@ -272,18 +272,18 @@ class FFTransformer(nn.Module):
|
||||||
dropatt=dropatt, pre_lnorm=pre_lnorm)
|
dropatt=dropatt, pre_lnorm=pre_lnorm)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, dec_inp, seq_lens=None):
|
def forward(self, dec_inp, seq_lens=None, conditioning=0):
|
||||||
if self.word_emb is None:
|
if self.word_emb is None:
|
||||||
inp = dec_inp
|
inp = dec_inp
|
||||||
mask = mask_from_lens(seq_lens).unsqueeze(2)
|
mask = mask_from_lens(seq_lens).unsqueeze(2)
|
||||||
else:
|
else:
|
||||||
inp = self.word_emb(dec_inp)
|
inp = self.word_emb(dec_inp)
|
||||||
# [bsz x L x 1]
|
# [bsz x L x 1]
|
||||||
mask = (dec_inp != pad_idx).unsqueeze(2)
|
mask = (dec_inp != self.padding_idx).unsqueeze(2)
|
||||||
|
|
||||||
pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
|
pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
|
||||||
pos_emb = self.pos_emb(pos_seq) * mask
|
pos_emb = self.pos_emb(pos_seq) * mask
|
||||||
out = self.drop(inp + pos_emb)
|
out = self.drop(inp + pos_emb + conditioning)
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
out = layer(out, mask=mask)
|
out = layer(out, mask=mask)
|
||||||
|
|
|
@ -19,7 +19,6 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from common.utils import mask_from_lens
|
from common.utils import mask_from_lens
|
||||||
from common.text.symbols import pad_idx, symbols
|
|
||||||
|
|
||||||
|
|
||||||
class NoOp(nn.Module):
|
class NoOp(nn.Module):
|
||||||
|
@ -255,20 +254,20 @@ class TransformerLayer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FFTransformer(nn.Module):
|
class FFTransformer(nn.Module):
|
||||||
pad_idx = 0 # XXX
|
|
||||||
|
|
||||||
def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
|
def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
|
||||||
dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
|
dropout, dropatt, dropemb=0.0, embed_input=True,
|
||||||
pre_lnorm=False):
|
n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
|
||||||
super(FFTransformer, self).__init__()
|
super(FFTransformer, self).__init__()
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
self.n_embed = n_embed
|
||||||
|
|
||||||
self.embed_input = embed_input
|
self.embed_input = embed_input
|
||||||
if embed_input:
|
if embed_input:
|
||||||
self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
|
self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
|
||||||
padding_idx=FFTransformer.pad_idx)
|
padding_idx=self.padding_idx)
|
||||||
else:
|
else:
|
||||||
self.word_emb = NoOp()
|
self.word_emb = NoOp()
|
||||||
|
|
||||||
|
@ -283,19 +282,23 @@ class FFTransformer(nn.Module):
|
||||||
dropatt=dropatt, pre_lnorm=pre_lnorm)
|
dropatt=dropatt, pre_lnorm=pre_lnorm)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None):
|
def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
|
||||||
if self.embed_input:
|
conditioning: Optional[torch.Tensor] = None):
|
||||||
inp = self.word_emb(dec_inp)
|
if not self.embed_input:
|
||||||
# [bsz x L x 1]
|
|
||||||
# mask = (dec_inp != FFTransformer.pad_idx).unsqueeze(2)
|
|
||||||
mask = (dec_inp != 0).unsqueeze(2)
|
|
||||||
else:
|
|
||||||
inp = dec_inp
|
inp = dec_inp
|
||||||
assert seq_lens is not None
|
assert seq_lens is not None
|
||||||
mask = mask_from_lens(seq_lens).unsqueeze(2)
|
mask = mask_from_lens(seq_lens).unsqueeze(2)
|
||||||
|
else:
|
||||||
|
inp = self.word_emb(dec_inp)
|
||||||
|
# [bsz x L x 1]
|
||||||
|
mask = (dec_inp != self.padding_idx).unsqueeze(2)
|
||||||
|
|
||||||
pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
|
pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
|
||||||
pos_emb = self.pos_emb(pos_seq) * mask
|
pos_emb = self.pos_emb(pos_seq) * mask
|
||||||
out = self.drop(inp + pos_emb)
|
if conditioning is not None:
|
||||||
|
out = self.drop(inp + pos_emb + conditioning)
|
||||||
|
else:
|
||||||
|
out = self.drop(inp + pos_emb)
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
out = layer(out, mask=mask)
|
out = layer(out, mask=mask)
|
||||||
|
|
|
@ -45,7 +45,7 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
|
||||||
from common import utils
|
from common import utils
|
||||||
from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
|
from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
|
||||||
unique_log_fpath)
|
unique_log_fpath)
|
||||||
from common.text import text_to_sequence
|
from common.text.text_processing import TextProcessing
|
||||||
from pitch_transform import pitch_transform_custom
|
from pitch_transform import pitch_transform_custom
|
||||||
from waveglow import model as glow
|
from waveglow import model as glow
|
||||||
from waveglow.denoiser import Denoiser
|
from waveglow.denoiser import Denoiser
|
||||||
|
@ -92,9 +92,11 @@ def parse_args(parser):
|
||||||
help='Use EMA averaged model (if saved in checkpoints)')
|
help='Use EMA averaged model (if saved in checkpoints)')
|
||||||
parser.add_argument('--dataset-path', type=str,
|
parser.add_argument('--dataset-path', type=str,
|
||||||
help='Path to dataset (for loading extra data fields)')
|
help='Path to dataset (for loading extra data fields)')
|
||||||
|
parser.add_argument('--speaker', type=int, default=0,
|
||||||
|
help='Speaker ID for a multi-speaker model')
|
||||||
|
|
||||||
transform = parser.add_argument_group('transform')
|
transform = parser.add_argument_group('transform')
|
||||||
transform.add_argument('--fade-out', type=int, default=5,
|
transform.add_argument('--fade-out', type=int, default=10,
|
||||||
help='Number of fadeout frames at the end')
|
help='Number of fadeout frames at the end')
|
||||||
transform.add_argument('--pace', type=float, default=1.0,
|
transform.add_argument('--pace', type=float, default=1.0,
|
||||||
help='Adjust the pace of speech')
|
help='Adjust the pace of speech')
|
||||||
|
@ -108,6 +110,18 @@ def parse_args(parser):
|
||||||
help='Raise/lower the pitch by <hz>')
|
help='Raise/lower the pitch by <hz>')
|
||||||
transform.add_argument('--pitch-transform-custom', action='store_true',
|
transform.add_argument('--pitch-transform-custom', action='store_true',
|
||||||
help='Apply the transform from pitch_transform.py')
|
help='Apply the transform from pitch_transform.py')
|
||||||
|
|
||||||
|
text_processing = parser.add_argument_group('Text processing parameters')
|
||||||
|
text_processing.add_argument('--text-cleaners', nargs='*',
|
||||||
|
default=['english_cleaners'], type=str,
|
||||||
|
help='Type of text cleaners for input text')
|
||||||
|
text_processing.add_argument('--symbol-set', type=str, default='english_basic',
|
||||||
|
help='Define symbol set for input text')
|
||||||
|
|
||||||
|
cond = parser.add_argument_group('conditioning on additional attributes')
|
||||||
|
cond.add_argument('--n-speakers', type=int, default=1,
|
||||||
|
help='Number of speakers in the model.')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,7 +152,7 @@ def load_and_setup_model(model_name, parser, checkpoint, amp, device,
|
||||||
|
|
||||||
if any(key.startswith('module.') for key in sd):
|
if any(key.startswith('module.') for key in sd):
|
||||||
sd = {k.replace('module.', ''): v for k,v in sd.items()}
|
sd = {k.replace('module.', ''): v for k,v in sd.items()}
|
||||||
status += ' ' + str(model.load_state_dict(sd, strict=False))
|
status += ' ' + str(model.load_state_dict(sd, strict=True))
|
||||||
else:
|
else:
|
||||||
model = checkpoint_data['model']
|
model = checkpoint_data['model']
|
||||||
print(f'Loaded {model_name}{status}')
|
print(f'Loaded {model_name}{status}')
|
||||||
|
@ -162,10 +176,13 @@ def load_fields(fpath):
|
||||||
return {c:f for c, f in zip(columns, fields)}
|
return {c:f for c, f in zip(columns, fields)}
|
||||||
|
|
||||||
|
|
||||||
def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
|
def prepare_input_sequence(fields, device, symbol_set, text_cleaners,
|
||||||
load_mels=False, load_pitch=False):
|
batch_size=128, dataset=None, load_mels=False,
|
||||||
fields['text'] = [torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
|
load_pitch=False):
|
||||||
for t in fields['text']]
|
tp = TextProcessing(symbol_set, text_cleaners)
|
||||||
|
|
||||||
|
fields['text'] = [torch.LongTensor(tp.encode_text(text))
|
||||||
|
for text in fields['text']]
|
||||||
order = np.argsort([-t.size(0) for t in fields['text']])
|
order = np.argsort([-t.size(0) for t in fields['text']])
|
||||||
|
|
||||||
fields['text'] = [fields['text'][i] for i in order]
|
fields['text'] = [fields['text'][i] for i in order]
|
||||||
|
@ -206,7 +223,6 @@ def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
|
||||||
|
|
||||||
|
|
||||||
def build_pitch_transformation(args):
|
def build_pitch_transformation(args):
|
||||||
|
|
||||||
if args.pitch_transform_custom:
|
if args.pitch_transform_custom:
|
||||||
def custom_(pitch, pitch_lens, mean, std):
|
def custom_(pitch, pitch_lens, mean, std):
|
||||||
return (pitch_transform_custom(pitch * std + mean, pitch_lens)
|
return (pitch_transform_custom(pitch * std + mean, pitch_lens)
|
||||||
|
@ -262,7 +278,7 @@ def main():
|
||||||
StdOutBackend(Verbosity.VERBOSE,
|
StdOutBackend(Verbosity.VERBOSE,
|
||||||
metric_format=stdout_metric_format)])
|
metric_format=stdout_metric_format)])
|
||||||
init_inference_metadata()
|
init_inference_metadata()
|
||||||
[DLLogger.log("PARAMETER", {k:v}) for k,v in vars(args).items()]
|
[DLLogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]
|
||||||
|
|
||||||
device = torch.device('cuda' if args.cuda else 'cpu')
|
device = torch.device('cuda' if args.cuda else 'cpu')
|
||||||
|
|
||||||
|
@ -293,8 +309,8 @@ def main():
|
||||||
|
|
||||||
fields = load_fields(args.input)
|
fields = load_fields(args.input)
|
||||||
batches = prepare_input_sequence(
|
batches = prepare_input_sequence(
|
||||||
fields, device, args.batch_size, args.dataset_path,
|
fields, device, args.symbol_set, args.text_cleaners, args.batch_size,
|
||||||
load_mels=(generator is None))
|
args.dataset_path, load_mels=(generator is None))
|
||||||
|
|
||||||
if args.include_warmup:
|
if args.include_warmup:
|
||||||
# Use real data rather than synthetic - FastPitch predicts len
|
# Use real data rather than synthetic - FastPitch predicts len
|
||||||
|
@ -311,11 +327,13 @@ def main():
|
||||||
waveglow_measures = MeasureTime()
|
waveglow_measures = MeasureTime()
|
||||||
|
|
||||||
gen_kw = {'pace': args.pace,
|
gen_kw = {'pace': args.pace,
|
||||||
|
'speaker': args.speaker,
|
||||||
'pitch_tgt': None,
|
'pitch_tgt': None,
|
||||||
'pitch_transform': build_pitch_transformation(args)}
|
'pitch_transform': build_pitch_transformation(args)}
|
||||||
|
|
||||||
if args.torchscript:
|
if args.torchscript:
|
||||||
gen_kw.pop('pitch_transform')
|
gen_kw.pop('pitch_transform')
|
||||||
|
print('NOTE: Pitch transforms are disabled with TorchScript')
|
||||||
|
|
||||||
all_utterances = 0
|
all_utterances = 0
|
||||||
all_samples = 0
|
all_samples = 0
|
||||||
|
@ -323,11 +341,10 @@ def main():
|
||||||
all_frames = 0
|
all_frames = 0
|
||||||
|
|
||||||
reps = args.repeats
|
reps = args.repeats
|
||||||
log_enabled = True # reps == 1
|
log_enabled = reps == 1
|
||||||
log = lambda s, d: DLLogger.log(step=s, data=d) if log_enabled else None
|
log = lambda s, d: DLLogger.log(step=s, data=d) if log_enabled else None
|
||||||
|
|
||||||
# for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
|
for rep in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
|
||||||
for rep in range(reps):
|
|
||||||
for b in batches:
|
for b in batches:
|
||||||
if generator is None:
|
if generator is None:
|
||||||
log(rep, {'Synthesizing from ground truth mels'})
|
log(rep, {'Synthesizing from ground truth mels'})
|
||||||
|
@ -348,7 +365,7 @@ def main():
|
||||||
audios = waveglow(mel, sigma=args.sigma_infer)
|
audios = waveglow(mel, sigma=args.sigma_infer)
|
||||||
audios = denoiser(audios.float(),
|
audios = denoiser(audios.float(),
|
||||||
strength=args.denoising_strength
|
strength=args.denoising_strength
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
|
|
||||||
all_utterances += len(audios)
|
all_utterances += len(audios)
|
||||||
all_samples += sum(audio.size(0) for audio in audios)
|
all_samples += sum(audio.size(0) for audio in audios)
|
||||||
|
@ -367,7 +384,7 @@ def main():
|
||||||
fade_w = torch.linspace(1.0, 0.0, fade_len)
|
fade_w = torch.linspace(1.0, 0.0, fade_len)
|
||||||
audio[-fade_len:] *= fade_w.to(audio.device)
|
audio[-fade_len:] *= fade_w.to(audio.device)
|
||||||
|
|
||||||
audio = audio/torch.max(torch.abs(audio))
|
audio = audio / torch.max(torch.abs(audio))
|
||||||
fname = b['output'][i] if 'output' in b else f'audio_{i}.wav'
|
fname = b['output'][i] if 'output' in b else f'audio_{i}.wav'
|
||||||
audio_path = Path(args.output, fname)
|
audio_path = Path(args.output, fname)
|
||||||
write(audio_path, args.sampling_rate, audio.cpu().numpy())
|
write(audio_path, args.sampling_rate, audio.cpu().numpy())
|
||||||
|
|
|
@ -37,6 +37,7 @@ from fastpitch.model import FastPitch as _FastPitch
|
||||||
from fastpitch.model_jit import FastPitch as _FastPitchJIT
|
from fastpitch.model_jit import FastPitch as _FastPitchJIT
|
||||||
from tacotron2.model import Tacotron2
|
from tacotron2.model import Tacotron2
|
||||||
from waveglow.model import WaveGlow
|
from waveglow.model import WaveGlow
|
||||||
|
from common.text.symbols import get_symbols, get_pad_idx
|
||||||
|
|
||||||
|
|
||||||
def parse_model_args(model_name, parser, add_help=False):
|
def parse_model_args(model_name, parser, add_help=False):
|
||||||
|
@ -94,25 +95,27 @@ def get_model(model_name, model_config, device,
|
||||||
model = WaveGlow(**model_config)
|
model = WaveGlow(**model_config)
|
||||||
|
|
||||||
elif model_name == 'FastPitch':
|
elif model_name == 'FastPitch':
|
||||||
|
|
||||||
if forward_is_infer:
|
if forward_is_infer:
|
||||||
|
|
||||||
if jitable:
|
if jitable:
|
||||||
class FastPitch__forward_is_infer(_FastPitchJIT):
|
class FastPitch__forward_is_infer(_FastPitchJIT):
|
||||||
def forward(self, inputs, input_lengths, pace: float = 1.0,
|
def forward(self, inputs, input_lengths, pace: float = 1.0,
|
||||||
dur_tgt: Optional[torch.Tensor] = None,
|
dur_tgt: Optional[torch.Tensor] = None,
|
||||||
pitch_tgt: Optional[torch.Tensor] = None):
|
pitch_tgt: Optional[torch.Tensor] = None,
|
||||||
|
speaker: int = 0):
|
||||||
return self.infer(inputs, input_lengths, pace=pace,
|
return self.infer(inputs, input_lengths, pace=pace,
|
||||||
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt)
|
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
|
||||||
|
speaker=speaker)
|
||||||
else:
|
else:
|
||||||
class FastPitch__forward_is_infer(_FastPitch):
|
class FastPitch__forward_is_infer(_FastPitch):
|
||||||
def forward(self, inputs, input_lengths, pace: float = 1.0,
|
def forward(self, inputs, input_lengths, pace: float = 1.0,
|
||||||
dur_tgt: Optional[torch.Tensor] = None,
|
dur_tgt: Optional[torch.Tensor] = None,
|
||||||
pitch_tgt: Optional[torch.Tensor] = None,
|
pitch_tgt: Optional[torch.Tensor] = None,
|
||||||
pitch_transform=None):
|
pitch_transform=None,
|
||||||
|
speaker: Optional[int] = None):
|
||||||
return self.infer(inputs, input_lengths, pace=pace,
|
return self.infer(inputs, input_lengths, pace=pace,
|
||||||
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
|
dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
|
||||||
pitch_transform=pitch_transform)
|
pitch_transform=pitch_transform,
|
||||||
|
speaker=speaker)
|
||||||
|
|
||||||
model = FastPitch__forward_is_infer(**model_config)
|
model = FastPitch__forward_is_infer(**model_config)
|
||||||
else:
|
else:
|
||||||
|
@ -136,7 +139,7 @@ def get_model_config(model_name, args):
|
||||||
# audio
|
# audio
|
||||||
n_mel_channels=args.n_mel_channels,
|
n_mel_channels=args.n_mel_channels,
|
||||||
# symbols
|
# symbols
|
||||||
n_symbols=args.n_symbols,
|
n_symbols=len(get_symbols(args.symbol_set)),
|
||||||
symbols_embedding_dim=args.symbols_embedding_dim,
|
symbols_embedding_dim=args.symbols_embedding_dim,
|
||||||
# encoder
|
# encoder
|
||||||
encoder_kernel_size=args.encoder_kernel_size,
|
encoder_kernel_size=args.encoder_kernel_size,
|
||||||
|
@ -183,7 +186,8 @@ def get_model_config(model_name, args):
|
||||||
n_mel_channels=args.n_mel_channels,
|
n_mel_channels=args.n_mel_channels,
|
||||||
max_seq_len=args.max_seq_len,
|
max_seq_len=args.max_seq_len,
|
||||||
# symbols
|
# symbols
|
||||||
n_symbols=args.n_symbols,
|
n_symbols=len(get_symbols(args.symbol_set)),
|
||||||
|
padding_idx=get_pad_idx(args.symbol_set),
|
||||||
symbols_embedding_dim=args.symbols_embedding_dim,
|
symbols_embedding_dim=args.symbols_embedding_dim,
|
||||||
# input FFT
|
# input FFT
|
||||||
in_fft_n_layers=args.in_fft_n_layers,
|
in_fft_n_layers=args.in_fft_n_layers,
|
||||||
|
@ -215,6 +219,11 @@ def get_model_config(model_name, args):
|
||||||
pitch_predictor_filter_size=args.pitch_predictor_filter_size,
|
pitch_predictor_filter_size=args.pitch_predictor_filter_size,
|
||||||
p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
|
p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
|
||||||
pitch_predictor_n_layers=args.pitch_predictor_n_layers,
|
pitch_predictor_n_layers=args.pitch_predictor_n_layers,
|
||||||
|
# pitch conditioning
|
||||||
|
pitch_embedding_kernel_size=args.pitch_embedding_kernel_size,
|
||||||
|
# speakers parameters
|
||||||
|
n_speakers=args.n_speakers,
|
||||||
|
speaker_emb_weight=args.speaker_emb_weight
|
||||||
)
|
)
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
|
@ -6,16 +6,13 @@ DATA_DIR="LJSpeech-1.1"
|
||||||
LJS_ARCH="LJSpeech-1.1.tar.bz2"
|
LJS_ARCH="LJSpeech-1.1.tar.bz2"
|
||||||
LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"
|
LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"
|
||||||
|
|
||||||
if [ ! -f ${LJS_ARCH} ]; then
|
if [ ! -d ${DATA_DIR} ]; then
|
||||||
echo "Downloading ${LJS_ARCH} ..."
|
echo "Downloading ${LJS_ARCH} ..."
|
||||||
wget -q ${LJS_URL}
|
wget -q ${LJS_URL}
|
||||||
fi
|
|
||||||
|
|
||||||
if [ ! -d ${DATA_DIR} ]; then
|
|
||||||
echo "Extracting ${LJS_ARCH} ..."
|
echo "Extracting ${LJS_ARCH} ..."
|
||||||
tar jxvf ${LJS_ARCH}
|
tar jxvf ${LJS_ARCH}
|
||||||
rm -f ${LJS_ARCH}
|
rm -f ${LJS_ARCH}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
bash scripts/download_tacotron2.sh
|
bash ./scripts/download_tacotron2.sh
|
||||||
bash scripts/download_waveglow.sh
|
bash ./scripts/download_waveglow.sh
|
||||||
|
|
|
@ -2,19 +2,20 @@
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
|
: ${MODEL_DIR:="pretrained_models/fastpitch"}
|
||||||
FASTP_ZIP="nvidia_fastpitch_200518.zip"
|
MODEL_ZIP="nvidia_fastpitch_200518.zip"
|
||||||
FASTP_CH="nvidia_fastpitch_200518.pt"
|
MODEL_CH="nvidia_fastpitch_200518.pt"
|
||||||
FASTP_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"
|
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_amp_ckpt_v1/versions/20.02.0/zip"
|
||||||
|
|
||||||
mkdir -p "$MODEL_DIR"/fastpitch
|
mkdir -p "$MODEL_DIR"
|
||||||
|
|
||||||
if [ ! -f "${MODEL_DIR}/fastpitch/${FASTP_ZIP}" ]; then
|
if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
|
||||||
echo "Downloading ${FASTP_ZIP} ..."
|
echo "Downloading ${MODEL_ZIP} ..."
|
||||||
wget -qO ${MODEL_DIR}/fastpitch/${FASTP_ZIP} ${FASTP_URL}
|
wget -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
|
||||||
|
|| echo "ERROR: Failed to download ${MODEL_ZIP} from NGC" && exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f "${MODEL_DIR}/fastpitch/${FASTP_CH}" ]; then
|
if [ ! -f "${MODEL_DIR}/${MODEL_CH}" ]; then
|
||||||
echo "Extracting ${FASTP_CH} ..."
|
echo "Extracting ${MODEL_CH} ..."
|
||||||
unzip -qo ${MODEL_DIR}/fastpitch/${FASTP_ZIP} -d ${MODEL_DIR}/fastpitch/
|
unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR}
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -2,12 +2,18 @@
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
|
: ${MODEL_DIR:="pretrained_models/tacotron2"}
|
||||||
TACO_CH="nvidia_tacotron2pyt_fp32_20190427.pt"
|
MODEL="nvidia_tacotron2pyt_fp16.pt"
|
||||||
TACO_URL="https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/2/files/nvidia_tacotron2pyt_fp32_20190427"
|
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_pyt_ckpt_amp/versions/19.12.0/files/nvidia_tacotron2pyt_fp16.pt"
|
||||||
|
|
||||||
if [ ! -f "${MODEL_DIR}/tacotron2/${TACO_CH}" ]; then
|
mkdir -p "$MODEL_DIR"
|
||||||
echo "Downloading ${TACO_CH} ..."
|
|
||||||
mkdir -p "$MODEL_DIR"/tacotron2
|
if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
|
||||||
wget -qO ${MODEL_DIR}/tacotron2/${TACO_CH} ${TACO_URL}
|
echo "Downloading ${MODEL} ..."
|
||||||
|
wget --content-disposition -qO ${MODEL_DIR}/${MODEL} ${MODEL_URL} \
|
||||||
|
|| echo "ERROR: Failed to download ${MODEL} from NGC" && exit 1
|
||||||
|
echo "OK"
|
||||||
|
|
||||||
|
else
|
||||||
|
echo "${MODEL}.pt already downloaded."
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -2,19 +2,26 @@
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
MODEL_DIR=${MODEL_DIR:-"pretrained_models"}
|
: ${MODEL_DIR:="pretrained_models/waveglow"}
|
||||||
WAVEG="waveglow_1076430_14000_amp"
|
MODEL="nvidia_waveglow256pyt_fp16"
|
||||||
WAVEG_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow256pyt_fp16/versions/2/zip"
|
MODEL_ZIP="waveglow_ckpt_amp_256_20.01.0.zip"
|
||||||
|
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp_256/versions/20.01.0/zip"
|
||||||
|
|
||||||
mkdir -p "$MODEL_DIR"/waveglow
|
mkdir -p "$MODEL_DIR"
|
||||||
|
|
||||||
if [ ! -f "${MODEL_DIR}/waveglow/${WAVEG}.zip" ]; then
|
if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
|
||||||
echo "Downloading ${WAVEG}.zip ..."
|
echo "Downloading ${MODEL_ZIP} ..."
|
||||||
wget -qO "${MODEL_DIR}/waveglow/${WAVEG}.zip" ${WAVEG_URL}
|
wget --content-disposition -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
|
||||||
|
|| echo "ERROR: Failed to download ${MODEL_ZIP} from NGC" && exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f "${MODEL_DIR}/waveglow/${WAVEG}.pt" ]; then
|
if [ ! -f "${MODEL_DIR}/${MODEL}.pt" ]; then
|
||||||
echo "Extracting ${WAVEG} ..."
|
echo "Extracting ${MODEL} ..."
|
||||||
unzip -qo "${MODEL_DIR}/waveglow/${WAVEG}.zip" -d ${MODEL_DIR}/waveglow/
|
unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
|
||||||
mv "${MODEL_DIR}/waveglow/${WAVEG}" "${MODEL_DIR}/waveglow/${WAVEG}.pt"
|
|| echo "ERROR: Failed to extract ${MODEL_ZIP}" && exit 1
|
||||||
|
|
||||||
|
echo "OK"
|
||||||
|
|
||||||
|
else
|
||||||
|
echo "${MODEL}.pt already downloaded."
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -1,22 +1,26 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
[ ! -n "$WAVEG_CH" ] && WAVEG_CH="pretrained_models/waveglow/waveglow_1076430_14000_amp.pt"
|
: ${WAVEGLOW:="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
|
||||||
[ ! -n "$FASTPITCH_CH" ] && FASTPITCH_CH="output/FastPitch_checkpoint_1500.pt"
|
: ${FASTPITCH:="output/FastPitch_checkpoint_1500.pt"}
|
||||||
[ ! -n "$REPEATS" ] && REPEATS=1000
|
: ${REPEATS:=1000}
|
||||||
[ ! -n "$BS_SEQ" ] && BS_SEQ="1 4 8"
|
: ${BS_SEQUENCE:="1 4 8"}
|
||||||
[ ! -n "$PHRASES" ] && PHRASES="phrases/benchmark_8_128.tsv"
|
: ${PHRASES:="phrases/benchmark_8_128.tsv"}
|
||||||
[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
|
: ${OUTPUT_DIR:="./output/audio_$(basename ${PHRASES} .tsv)"}
|
||||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp" || AMP=false
|
: ${AMP:=false}
|
||||||
|
|
||||||
for BS in $BS_SEQ ; do
|
[ "$AMP" = true ] && AMP_FLAG="--amp"
|
||||||
|
|
||||||
|
mkdir -o "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
for BS in $BS_SEQUENCE ; do
|
||||||
|
|
||||||
echo -e "\nAMP: ${AMP}, batch size: ${BS}\n"
|
echo -e "\nAMP: ${AMP}, batch size: ${BS}\n"
|
||||||
|
|
||||||
python inference.py --cuda --cudnn-benchmark \
|
python inference.py --cuda --cudnn-benchmark \
|
||||||
-i ${PHRASES} \
|
-i ${PHRASES} \
|
||||||
-o ${OUTPUT_DIR} \
|
-o ${OUTPUT_DIR} \
|
||||||
--fastpitch ${FASTPITCH_CH} \
|
--fastpitch ${FASTPITCH} \
|
||||||
--waveglow ${WAVEG_CH} \
|
--waveglow ${WAVEGLOW} \
|
||||||
--wn-channels 256 \
|
--wn-channels 256 \
|
||||||
--include-warmup \
|
--include-warmup \
|
||||||
--batch-size ${BS} \
|
--batch-size ${BS} \
|
||||||
|
|
|
@ -1,22 +1,21 @@
|
||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
DATA_DIR="LJSpeech-1.1"
|
: ${WAVEGLOW:="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
|
||||||
|
: ${FASTPITCH:="output/FastPitch_checkpoint_1500.pt"}
|
||||||
|
: ${BS:=32}
|
||||||
|
: ${PHRASES:="phrases/devset10.tsv"}
|
||||||
|
: ${OUTPUT_DIR:="./output/audio_$(basename ${PHRASES} .tsv)"}
|
||||||
|
: ${AMP:=false}
|
||||||
|
|
||||||
[ ! -n "$WAVEG_CH" ] && WAVEG_CH="pretrained_models/waveglow/waveglow_1076430_14000_amp.pt"
|
[ "$AMP" = true ] && AMP_FLAG="--amp"
|
||||||
[ ! -n "$FASTPITCH_CH" ] && FASTPITCH_CH="output/FastPitch_checkpoint_1500.pt"
|
|
||||||
[ ! -n "$BS" ] && BS=32
|
|
||||||
[ ! -n "$PHRASES" ] && PHRASES="phrases/devset10.tsv"
|
|
||||||
[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
|
|
||||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp"
|
|
||||||
|
|
||||||
mkdir -p "$OUTPUT_DIR"
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
python inference.py --cuda \
|
python inference.py --cuda \
|
||||||
-i ${PHRASES} \
|
-i ${PHRASES} \
|
||||||
-o ${OUTPUT_DIR} \
|
-o ${OUTPUT_DIR} \
|
||||||
--dataset-path ${DATA_DIR} \
|
--fastpitch ${FASTPITCH} \
|
||||||
--fastpitch ${FASTPITCH_CH} \
|
--waveglow ${WAVEGLOW} \
|
||||||
--waveglow ${WAVEG_CH} \
|
|
||||||
--wn-channels 256 \
|
--wn-channels 256 \
|
||||||
--batch-size ${BS} \
|
--batch-size ${BS} \
|
||||||
${AMP_FLAG}
|
${AMP_FLAG}
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
DATA_DIR="LJSpeech-1.1"
|
DATA_DIR="LJSpeech-1.1"
|
||||||
TACO_CH="pretrained_models/tacotron2/nvidia_tacotron2pyt_fp32_20190427.pt"
|
TACOTRON2="pretrained_models/tacotron2/nvidia_tacotron2pyt_fp16.pt"
|
||||||
for FILELIST in ljs_audio_text_train_filelist.txt \
|
for FILELIST in ljs_audio_text_train_filelist.txt \
|
||||||
ljs_audio_text_val_filelist.txt \
|
ljs_audio_text_val_filelist.txt \
|
||||||
ljs_audio_text_test_filelist.txt \
|
ljs_audio_text_test_filelist.txt \
|
||||||
|
@ -16,5 +16,5 @@ for FILELIST in ljs_audio_text_train_filelist.txt \
|
||||||
--extract-mels \
|
--extract-mels \
|
||||||
--extract-durations \
|
--extract-durations \
|
||||||
--extract-pitch-char \
|
--extract-pitch-char \
|
||||||
--tacotron2-checkpoint ${TACO_CH}
|
--tacotron2-checkpoint ${TACOTRON2}
|
||||||
done
|
done
|
||||||
|
|
|
@ -2,24 +2,26 @@
|
||||||
|
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
|
|
||||||
|
: ${NUM_GPUS:=8}
|
||||||
|
: ${BS:=32}
|
||||||
|
: ${GRAD_ACCUMULATION:=1}
|
||||||
|
: ${OUTPUT_DIR:="./output"}
|
||||||
|
: ${AMP:=false}
|
||||||
|
: ${EPOCHS:=1500}
|
||||||
|
|
||||||
|
[ "$AMP" == "true" ] && AMP_FLAG="--amp"
|
||||||
|
|
||||||
# Adjust env variables to maintain the global batch size
|
# Adjust env variables to maintain the global batch size
|
||||||
#
|
#
|
||||||
# NGPU x BS x GRAD_ACC = 256.
|
# NGPU x BS x GRAD_ACC = 256.
|
||||||
|
#
|
||||||
[ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output"
|
GBS=$(($NUM_GPUS * $BS * $GRAD_ACCUMULATION))
|
||||||
[ ! -n "$NGPU" ] && NGPU=8
|
|
||||||
[ ! -n "$BS" ] && BS=32
|
|
||||||
[ ! -n "$GRAD_ACC" ] && GRAD_ACC=1
|
|
||||||
[ ! -n "$EPOCHS" ] && EPOCHS=1500
|
|
||||||
[ "$AMP" == "true" ] && AMP_FLAG="--amp"
|
|
||||||
|
|
||||||
GBS=$(($NGPU * $BS * $GRAD_ACC))
|
|
||||||
[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}.\n"
|
[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}.\n"
|
||||||
|
|
||||||
echo -e "\nSetup: ${NGPU}x${BS}x${GRAD_ACC} - global batch size ${GBS}\n"
|
echo -e "\nSetup: ${NUM_GPUS}x${BS}x${GRAD_ACCUMULATION} - global batch size ${GBS}\n"
|
||||||
|
|
||||||
mkdir -p "$OUTPUT_DIR"
|
mkdir -p "$OUTPUT_DIR"
|
||||||
python -m torch.distributed.launch --nproc_per_node ${NGPU} train.py \
|
python -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} train.py \
|
||||||
--cuda \
|
--cuda \
|
||||||
-o "$OUTPUT_DIR/" \
|
-o "$OUTPUT_DIR/" \
|
||||||
--log-file "$OUTPUT_DIR/nvlog.json" \
|
--log-file "$OUTPUT_DIR/nvlog.json" \
|
||||||
|
@ -37,5 +39,5 @@ python -m torch.distributed.launch --nproc_per_node ${NGPU} train.py \
|
||||||
--dur-predictor-loss-scale 0.1 \
|
--dur-predictor-loss-scale 0.1 \
|
||||||
--pitch-predictor-loss-scale 0.1 \
|
--pitch-predictor-loss-scale 0.1 \
|
||||||
--weight-decay 1e-6 \
|
--weight-decay 1e-6 \
|
||||||
--gradient-accumulation-steps ${GRAD_ACC} \
|
--gradient-accumulation-steps ${GRAD_ACCUMULATION} \
|
||||||
${AMP_FLAG}
|
${AMP_FLAG}
|
||||||
|
|
|
@ -27,9 +27,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from common.text import symbols
|
|
||||||
|
|
||||||
|
|
||||||
def parse_tacotron2_args(parent, add_help=False):
|
def parse_tacotron2_args(parent, add_help=False):
|
||||||
"""
|
"""
|
||||||
Parse commandline arguments.
|
Parse commandline arguments.
|
||||||
|
@ -43,11 +40,7 @@ def parse_tacotron2_args(parent, add_help=False):
|
||||||
help='Number of bins in mel-spectrograms')
|
help='Number of bins in mel-spectrograms')
|
||||||
|
|
||||||
# symbols parameters
|
# symbols parameters
|
||||||
global symbols
|
|
||||||
len_symbols = len(symbols)
|
|
||||||
symbols = parser.add_argument_group('symbols parameters')
|
symbols = parser.add_argument_group('symbols parameters')
|
||||||
symbols.add_argument('--n-symbols', default=len_symbols, type=int,
|
|
||||||
help='Number of symbols in dictionary')
|
|
||||||
symbols.add_argument('--symbols-embedding-dim', default=512, type=int,
|
symbols.add_argument('--symbols-embedding-dim', default=512, type=int,
|
||||||
help='Input embedding dimension')
|
help='Input embedding dimension')
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,6 @@ import torch.utils.data
|
||||||
|
|
||||||
import common.layers as layers
|
import common.layers as layers
|
||||||
from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
|
from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
|
||||||
from common.text import text_to_sequence
|
|
||||||
|
|
||||||
|
|
||||||
class TextMelLoader(torch.utils.data.Dataset):
|
class TextMelLoader(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
|
@ -43,8 +41,9 @@ class TextMelLoader(torch.utils.data.Dataset):
|
||||||
3) computes mel-spectrograms from audio files.
|
3) computes mel-spectrograms from audio files.
|
||||||
"""
|
"""
|
||||||
def __init__(self, dataset_path, audiopaths_and_text, args, load_mel_from_disk=True):
|
def __init__(self, dataset_path, audiopaths_and_text, args, load_mel_from_disk=True):
|
||||||
self.audiopaths_and_text = load_filepaths_and_text(dataset_path, audiopaths_and_text)
|
self.audiopaths_and_text = load_filepaths_and_text(
|
||||||
self.text_cleaners = args.text_cleaners
|
dataset_path, audiopaths_and_text,
|
||||||
|
has_speakers=(args.n_speakers > 1))
|
||||||
self.load_mel_from_disk = load_mel_from_disk
|
self.load_mel_from_disk = load_mel_from_disk
|
||||||
if not load_mel_from_disk:
|
if not load_mel_from_disk:
|
||||||
self.max_wav_value = args.max_wav_value
|
self.max_wav_value = args.max_wav_value
|
||||||
|
@ -74,14 +73,14 @@ class TextMelLoader(torch.utils.data.Dataset):
|
||||||
return melspec
|
return melspec
|
||||||
|
|
||||||
def get_text(self, text):
|
def get_text(self, text):
|
||||||
text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
|
text_encoded = torch.IntTensor(self.tp.encode_text(text))
|
||||||
return text_norm
|
return text_encoded
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# separate filename and text
|
# separate filename and text
|
||||||
audiopath, text = self.audiopaths_and_text[index]
|
audiopath, text = self.audiopaths_and_text[index]
|
||||||
len_text = len(text)
|
|
||||||
text = self.get_text(text)
|
text = self.get_text(text)
|
||||||
|
len_text = len(text)
|
||||||
mel = self.get_mel(audiopath)
|
mel = self.get_mel(audiopath)
|
||||||
return (text, mel, len_text)
|
return (text, mel, len_text)
|
||||||
|
|
||||||
|
|
|
@ -108,14 +108,20 @@ def parse_args(parser):
|
||||||
|
|
||||||
dataset = parser.add_argument_group('dataset parameters')
|
dataset = parser.add_argument_group('dataset parameters')
|
||||||
dataset.add_argument('--training-files', type=str, required=True,
|
dataset.add_argument('--training-files', type=str, required=True,
|
||||||
help='Path to training filelist')
|
help='Path to training filelist. Separate multiple paths with commas.')
|
||||||
dataset.add_argument('--validation-files', type=str, required=True,
|
dataset.add_argument('--validation-files', type=str, required=True,
|
||||||
help='Path to validation filelist')
|
help='Path to validation filelist. Separate multiple paths with commas.')
|
||||||
dataset.add_argument('--pitch-mean-std-file', type=str, default=None,
|
dataset.add_argument('--pitch-mean-std-file', type=str, default=None,
|
||||||
help='Path to pitch stats to be stored in the model')
|
help='Path to pitch stats to be stored in the model')
|
||||||
dataset.add_argument('--text-cleaners', nargs='*',
|
dataset.add_argument('--text-cleaners', nargs='*',
|
||||||
default=['english_cleaners'], type=str,
|
default=['english_cleaners'], type=str,
|
||||||
help='Type of text cleaners for input text')
|
help='Type of text cleaners for input text')
|
||||||
|
dataset.add_argument('--symbol-set', type=str, default='english_basic',
|
||||||
|
help='Define symbol set for input text')
|
||||||
|
|
||||||
|
cond = parser.add_argument_group('conditioning on additional attributes')
|
||||||
|
cond.add_argument('--n-speakers', type=int, default=1,
|
||||||
|
help='Condition on speaker, value > 1 enables trainable speaker embeddings.')
|
||||||
|
|
||||||
distributed = parser.add_argument_group('distributed setup')
|
distributed = parser.add_argument_group('distributed setup')
|
||||||
distributed.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
|
distributed.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
|
||||||
|
@ -316,8 +322,6 @@ def main():
|
||||||
model = models.get_model('FastPitch', model_config, device)
|
model = models.get_model('FastPitch', model_config, device)
|
||||||
|
|
||||||
# Store pitch mean/std as params to translate from Hz during inference
|
# Store pitch mean/std as params to translate from Hz during inference
|
||||||
fpath = common.utils.stats_filename(
|
|
||||||
args.dataset_path, args.training_files, 'pitch_char')
|
|
||||||
with open(args.pitch_mean_std_file, 'r') as f:
|
with open(args.pitch_mean_std_file, 'r') as f:
|
||||||
stats = json.load(f)
|
stats = json.load(f)
|
||||||
model.pitch_mean[0] = stats['mean']
|
model.pitch_mean[0] = stats['mean']
|
||||||
|
@ -530,6 +534,13 @@ def main():
|
||||||
validate(model, None, total_iter, criterion, valset, args.batch_size,
|
validate(model, None, total_iter, criterion, valset, args.batch_size,
|
||||||
collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)
|
collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)
|
||||||
|
|
||||||
|
if (epoch > 0 and args.epochs_per_checkpoint > 0 and
|
||||||
|
(epoch % args.epochs_per_checkpoint != 0) and args.local_rank == 0):
|
||||||
|
checkpoint_path = os.path.join(
|
||||||
|
args.output, f"FastPitch_checkpoint_{epoch}.pt")
|
||||||
|
save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
|
||||||
|
total_iter, model_config, args.amp, checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -58,6 +58,7 @@ class Invertible1x1Conv(torch.nn.Module):
|
||||||
if torch.det(W) < 0:
|
if torch.det(W) < 0:
|
||||||
W[:, 0] = -1 * W[:, 0]
|
W[:, 0] = -1 * W[:, 0]
|
||||||
W = W.view(c, c, 1)
|
W = W.view(c, c, 1)
|
||||||
|
W = W.contiguous()
|
||||||
self.conv.weight.data = W
|
self.conv.weight.data = W
|
||||||
|
|
||||||
def forward(self, z):
|
def forward(self, z):
|
||||||
|
@ -279,6 +280,49 @@ class WaveGlow(torch.nn.Module):
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def infer_onnx(self, spect, z, sigma=0.9):
|
||||||
|
|
||||||
|
spect = self.upsample(spect)
|
||||||
|
# trim conv artifacts. maybe pad spec to kernel multiple
|
||||||
|
time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
|
||||||
|
spect = spect[:, :, :-time_cutoff]
|
||||||
|
|
||||||
|
length_spect_group = spect.size(2)//8
|
||||||
|
mel_dim = 80
|
||||||
|
batch_size = spect.size(0)
|
||||||
|
|
||||||
|
spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
|
||||||
|
spect = spect.permute(0, 2, 1, 3)
|
||||||
|
spect = spect.contiguous()
|
||||||
|
spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
|
||||||
|
spect = spect.permute(0, 2, 1)
|
||||||
|
spect = spect.contiguous()
|
||||||
|
|
||||||
|
audio = z[:, :self.n_remaining_channels, :]
|
||||||
|
z = z[:, self.n_remaining_channels:self.n_group, :]
|
||||||
|
audio = sigma*audio
|
||||||
|
|
||||||
|
for k in reversed(range(self.n_flows)):
|
||||||
|
n_half = int(audio.size(1) // 2)
|
||||||
|
audio_0 = audio[:, :n_half, :]
|
||||||
|
audio_1 = audio[:, n_half:(n_half+n_half), :]
|
||||||
|
|
||||||
|
output = self.WN[k]((audio_0, spect))
|
||||||
|
s = output[:, n_half:(n_half+n_half), :]
|
||||||
|
b = output[:, :n_half, :]
|
||||||
|
audio_1 = (audio_1 - b) / torch.exp(s)
|
||||||
|
audio = torch.cat([audio_0, audio_1], 1)
|
||||||
|
audio = self.convinv[k].infer(audio)
|
||||||
|
|
||||||
|
if k % self.n_early_every == 0 and k > 0:
|
||||||
|
audio = torch.cat((z[:, :self.n_early_size, :], audio), 1)
|
||||||
|
z = z[:, self.n_early_size:self.n_group, :]
|
||||||
|
|
||||||
|
audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
|
||||||
|
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_weightnorm(model):
|
def remove_weightnorm(model):
|
||||||
waveglow = model
|
waveglow = model
|
||||||
|
|
Loading…
Reference in a new issue