Merge branch 'main' into cluster_model
This commit is contained in:
commit
426bfeebea
76
Jenkinsfile
vendored
76
Jenkinsfile
vendored
|
@ -190,33 +190,7 @@ pipeline {
|
|||
sh 'rm -rf examples/asr/speech_to_text_results'
|
||||
}
|
||||
}
|
||||
// stage('Speech to Text - DALI AudioToMelSpectrogramPreprocessor') {
|
||||
// steps {
|
||||
// sh 'python examples/asr/speech_to_text.py \
|
||||
// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
|
||||
// +model.train_ds.use_dali=True \
|
||||
// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
|
||||
// +model.validation_ds.use_dali=True \
|
||||
// trainer.gpus=[0] \
|
||||
// +trainer.fast_dev_run=True \
|
||||
// exp_manager.exp_dir=examples/asr/speech_to_text_results'
|
||||
// sh 'rm -rf examples/asr/speech_to_text_results'
|
||||
// }
|
||||
// }
|
||||
// stage('Speech to Text - DALI AudioToMFCCPreprocessor') {
|
||||
// steps {
|
||||
// sh 'python examples/asr/speech_to_text.py \
|
||||
// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
|
||||
// +model.train_ds.use_dali=True \
|
||||
// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
|
||||
// +model.validation_ds.use_dali=True \
|
||||
// model.preprocessor._target_=nemo.collections.asr.modules.AudioToMFCCPreprocessor \
|
||||
// trainer.gpus=[0] \
|
||||
// +trainer.fast_dev_run=True \
|
||||
// exp_manager.exp_dir=examples/asr/speech_to_text_results'
|
||||
// sh 'rm -rf examples/asr/speech_to_text_results'
|
||||
// }
|
||||
// }
|
||||
|
||||
stage('Speech to Label') {
|
||||
steps {
|
||||
sh 'python examples/asr/speech_to_label.py \
|
||||
|
@ -299,6 +273,54 @@ pipeline {
|
|||
}
|
||||
}
|
||||
|
||||
stage('L2: ASR DALI dev run') {
|
||||
when {
|
||||
anyOf {
|
||||
branch 'main'
|
||||
changeRequest target: 'main'
|
||||
}
|
||||
}
|
||||
failFast true
|
||||
parallel {
|
||||
stage('Speech to Text - DALI AudioToMelSpectrogramPreprocessor') {
|
||||
steps {
|
||||
sh 'python examples/asr/speech_to_text.py \
|
||||
model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
|
||||
+model.train_ds.use_dali=True \
|
||||
model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
|
||||
+model.validation_ds.use_dali=True \
|
||||
trainer.gpus=[0] \
|
||||
+trainer.fast_dev_run=True \
|
||||
exp_manager.exp_dir=examples/asr/speech_to_text_results'
|
||||
sh 'rm -rf examples/asr/speech_to_text_results'
|
||||
}
|
||||
}
|
||||
// TODO: This would fail due to an unnecessary torchaudio import.
|
||||
// To be enabled once torchaudio is available in the container used for CI
|
||||
// stage('Speech to Text - DALI AudioToMFCCPreprocessor') {
|
||||
// steps {
|
||||
// sh 'python examples/asr/speech_to_text.py \
|
||||
// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
|
||||
// +model.train_ds.use_dali=True \
|
||||
// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
|
||||
// +model.validation_ds.use_dali=True \
|
||||
// model.preprocessor._target_=nemo.collections.asr.modules.AudioToMFCCPreprocessor \
|
||||
// ~model.preprocessor.normalize \
|
||||
// ~model.preprocessor.features \
|
||||
// ~model.preprocessor.frame_splicing \
|
||||
// ~model.preprocessor.dither \
|
||||
// ~model.preprocessor.stft_conv \
|
||||
// +model.n_mels=64 \
|
||||
// +model.n_mfcc=64 \
|
||||
// trainer.gpus=[0] \
|
||||
// +trainer.fast_dev_run=True \
|
||||
// exp_manager.exp_dir=examples/asr/speech_to_text_results'
|
||||
// sh 'rm -rf examples/asr/speech_to_text_results'
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: UNCOMMENT TESTS AFTER 21.04 release (numba 0.53 min requirement)
|
||||
// stage('L2: ASR RNNT dev run') {
|
||||
// when {
|
||||
|
|
|
@ -32,12 +32,6 @@
|
|||
Introduction
|
||||
------------
|
||||
|
||||
NeMo is a toolkit for creating `Conversational AI <https://developer.nvidia.com/conversational-ai#started>`_ applications.
|
||||
|
||||
`NeMo product page. <https://developer.nvidia.com/nvidia-nemo>`_
|
||||
|
||||
`Introductory video. <https://www.youtube.com/embed/wBgpMf_KQVw>`_
|
||||
|
||||
The toolkit comes with extendable collections of pre-built modules and ready-to-use models for:
|
||||
|
||||
* `Automatic Speech Recognition (ASR) <https://ngc.nvidia.com/catalog/collections/nvidia:nemo_asr>`_
|
||||
|
|
|
@ -62,7 +62,7 @@ Modules
|
|||
Parts
|
||||
-----
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.jasper.JasperBlock
|
||||
.. autoclass:: nemo.collections.asr.parts.submodules.jasper.JasperBlock
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
|
@ -70,11 +70,11 @@ Parts
|
|||
Mixins
|
||||
------
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.mixins.ASRBPEMixin
|
||||
.. autoclass:: nemo.collections.asr.parts.mixins.mixins.ASRBPEMixin
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.mixins.ASRModuleMixin
|
||||
.. autoclass:: nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
|
@ -129,39 +129,39 @@ Audio Augmentors
|
|||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.SpeedPerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.SpeedPerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.TimeStretchPerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.TimeStretchPerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.GainPerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.GainPerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.ImpulsePerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.ImpulsePerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.ShiftPerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.ShiftPerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.NoisePerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.NoisePerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.WhiteNoisePerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.WhiteNoisePerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.RirAndNoisePerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.RirAndNoisePerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.perturb.TranscodePerturbation
|
||||
.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.TranscodePerturbation
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
|
@ -179,25 +179,25 @@ RNNT Decoding
|
|||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.rnnt_greedy_decoding.GreedyRNNTInfer
|
||||
.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyRNNTInfer
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.rnnt_greedy_decoding.GreedyBatchedRNNTInfer
|
||||
.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedRNNTInfer
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.rnnt_beam_decoding.BeamRNNTInfer
|
||||
.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_beam_decoding.BeamRNNTInfer
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
||||
Hypotheses
|
||||
~~~~~~~~~~
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.rnnt_utils.Hypothesis
|
||||
.. autoclass:: nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis
|
||||
:show-inheritance:
|
||||
:no-members:
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.rnnt_utils.NBestHypotheses
|
||||
.. autoclass:: nemo.collections.asr.parts.utils.rnnt_utils.NBestHypotheses
|
||||
:show-inheritance:
|
||||
:no-members:
|
||||
|
|
|
@ -342,7 +342,7 @@ configuration is a shortform notation for Citrinet-21x5xC, such that ``B = 21``
|
|||
not be changed.
|
||||
|
||||
To use Citrinet instead of QuartzNet, refer to the ``citrinet_512.yaml`` configuration found inside the ``examples/asr/conf/citrinet``
|
||||
directory. Citrinet is primarily comprised of the same :class:`~nemo.collections.asr.parts.jasper.JasperBlock` as ``Jasper`` or
|
||||
directory. Citrinet is primarily comprised of the same :class:`~nemo.collections.asr.parts.submodules.jasper.JasperBlock` as ``Jasper`` or
|
||||
``QuartzNet`.
|
||||
|
||||
While the configs for Citrinet and QuartzNet are similar, we note the additional flags used for Citrinet below. Refer to the
|
||||
|
@ -442,7 +442,7 @@ changed slightly as Citrinet utilizes sub-word tokenization.
|
|||
.. note::
|
||||
The following information is relevant to any of the above models that implements its encoder as an :class:`~nemo.collections.asr.modules.conv_asr.ConvASREncoder`, and utilizes the ``SqueezeExcite`` mechanism.
|
||||
|
||||
The ``SqueezeExcite`` block within a :class:`~nemo.collections.asr.modules.conv_asr.ConvASREncoder` network can be modified to utilize a different context window after the model has been instantiated (even after the model has been trained) so as to evaluate the model with limited context. This can be achieved using the :meth:`~nemo.collections.asr.parts.mixins.ASRModuleMixin.change_conv_asr_se_context_window`
|
||||
The ``SqueezeExcite`` block within a :class:`~nemo.collections.asr.modules.conv_asr.ConvASREncoder` network can be modified to utilize a different context window after the model has been instantiated (even after the model has been trained) so as to evaluate the model with limited context. This can be achieved using the :meth:`~nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin.change_conv_asr_se_context_window`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -473,3 +473,56 @@ specify the tokenizer if you want to use sub-word encoding instead of character-
|
|||
|
||||
The encoder section includes the details about the Conformer-CTC encoder architecture. You may find more information in the
|
||||
config files and also :doc:`nemo.collections.asr.modules.ConformerEncoder<./api.html#nemo.collections.asr.modules.ConformerEncoder>`.
|
||||
|
||||
|
||||
Fine-tuning Configurations
|
||||
-------------------------
|
||||
|
||||
All ASR scripts support easy fine-tuning by partially/fully loading the pretrained weights from a checkpoint into the currently instantiated model. Pre-trained weights can be provided in multiple ways -
|
||||
|
||||
1) Providing a path to a NeMo model (via ``init_from_nemo_model``)
|
||||
2) Providing a name of a pretrained NeMo model (which will be downloaded via the cloud) (via ``init_from_pretrained_model``)
|
||||
3) Providing a path to a Pytorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)
|
||||
|
||||
Fine-tuning via a NeMo model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
python examples/asr/script_to_<script_name>.py \
|
||||
--config-path=<path to dir of configs> \
|
||||
--config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="<path to manifest file>" \
|
||||
model.validation_ds.manifest_filepath="<path to manifest file>" \
|
||||
trainer.gpus=-1 \
|
||||
trainer.max_epochs=50 \
|
||||
+init_from_nemo_model="<path to .nemo model file>"
|
||||
|
||||
|
||||
Fine-tuning via a NeMo pretrained model name
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
python examples/asr/script_to_<script_name>.py \
|
||||
--config-path=<path to dir of configs> \
|
||||
--config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="<path to manifest file>" \
|
||||
model.validation_ds.manifest_filepath="<path to manifest file>" \
|
||||
trainer.gpus=-1 \
|
||||
trainer.max_epochs=50 \
|
||||
+init_from_pretrained_model="<name of pretrained checkpoint>"
|
||||
|
||||
Fine-tuning via a Pytorch Lightning checkpoint
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
python examples/asr/script_to_<script_name>.py \
|
||||
--config-path=<path to dir of configs> \
|
||||
--config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="<path to manifest file>" \
|
||||
model.validation_ds.manifest_filepath="<path to manifest file>" \
|
||||
trainer.gpus=-1 \
|
||||
trainer.max_epochs=50 \
|
||||
+init_from_ptl_ckpt="<name of pytorch lightning checkpoint>"
|
|
@ -12,6 +12,6 @@ Model Classes
|
|||
Mixins
|
||||
------
|
||||
|
||||
.. autoclass:: nemo.collections.asr.parts.mixins.DiarizationMixin
|
||||
.. autoclass:: nemo.collections.asr.parts.mixins.mixins.DiarizationMixin
|
||||
:show-inheritance:
|
||||
:members:
|
||||
|
|
|
@ -80,7 +80,7 @@ minimum and maximum SNR specified with min_snr and max_snr respectively. This se
|
|||
max_snr_db: 15
|
||||
|
||||
|
||||
See the :class:`nemo.collections.asr.parts.perturb.AudioAugmentor` API section for more details.
|
||||
See the :class:`nemo.collections.asr.parts.preprocessing.perturb.AudioAugmentor` API section for more details.
|
||||
|
||||
|
||||
Model Architecture Configurations
|
||||
|
|
|
@ -52,6 +52,7 @@ autodoc_mock_imports = [
|
|||
'nemo_text_processing.inverse_text_normalization', # Not installed automatically
|
||||
'nemo_text_processing.text_normalization', # Not installed automatically
|
||||
'attr', # attrdict in requirements, attr in import
|
||||
'torchmetrics', # inherited from PTL
|
||||
]
|
||||
|
||||
_skipped_autodoc_mock_imports = ['wrapt', 'numpy']
|
||||
|
|
|
@ -98,3 +98,12 @@
|
|||
journal={arXiv preprint arXiv:1804.08771},
|
||||
year={2018}
|
||||
}
|
||||
|
||||
@misc{zhang2021sgdqa,
|
||||
title={SGD-QA: Fast Schema-Guided Dialogue State Tracking for Unseen Services},
|
||||
author={Yang Zhang and Vahid Noroozi and Evelina Bakhturina and Boris Ginsburg},
|
||||
year={2021},
|
||||
eprint={2105.08049},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
|
@ -3,4 +3,15 @@
|
|||
Dialogue State Tracking - SGD-QA Model
|
||||
======================================
|
||||
|
||||
More details and an example script on how to train the model can be found here: `NeMo/examples/nlp/dialogue_state_tracking/sgd_qa.py <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/dialogue_state_tracking/sgd_qa.py>`__.
|
||||
More details can be found in the paper
|
||||
`SGD-QA: Fast Schema-Guided Dialogue State Tracking for Unseen Services <https://arxiv.org/abs/2105.08049>`__ :cite:`nlp-sgdqa-zhang2021sgdqa`.
|
||||
An example script on how to train the model can be found here: `NeMo/examples/nlp/dialogue_state_tracking/sgd_qa.py <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/dialogue_state_tracking/sgd_qa.py>`__.
|
||||
|
||||
|
||||
References
|
||||
----------
|
||||
|
||||
.. bibliography:: nlp_all.bib
|
||||
:style: plain
|
||||
:labelprefix: NLP-SGDQA
|
||||
:keyprefix: nlp-sgdqa-
|
|
@ -102,20 +102,37 @@ python speech_to_label.py \
|
|||
+trainer.precision=16 \
|
||||
+trainer.amp_level=O1 # needed if using PyTorch < 1.6
|
||||
|
||||
# Fine-tune a model
|
||||
|
||||
For documentation on fine-tuning this model, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
|
||||
|
||||
# Pretrained Models
|
||||
|
||||
For documentation on existing pretrained models, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speech_classification/results.html#
|
||||
|
||||
"""
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.models import EncDecClassificationModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
@hydra_runner(config_path="conf", config_name="matchboxnet_3x1x64_v1")
|
||||
def main(cfg):
|
||||
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
|
||||
|
||||
trainer = pl.Trainer(**cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
asr_model = EncDecClassificationModel(cfg=cfg.model, trainer=trainer)
|
||||
|
||||
# Initialize the weights of the model from another model, if provided via config
|
||||
asr_model.maybe_init_from_pretrained_checkpoint(cfg)
|
||||
|
||||
trainer.fit(asr_model)
|
||||
|
||||
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
|
||||
|
|
|
@ -12,21 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.models import EncDecCTCModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
"""
|
||||
# Training the model
|
||||
|
||||
Basic run (on CPU for 50 epochs):
|
||||
python examples/asr/speech_to_text.py \
|
||||
model.train_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_train.json" \
|
||||
model.validation_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_val.json" \
|
||||
hydra.run.dir="." \
|
||||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="<path to manifest file>" \
|
||||
model.validation_ds.manifest_filepath="<path to manifest file>" \
|
||||
trainer.gpus=0 \
|
||||
trainer.max_epochs=50
|
||||
|
||||
|
@ -41,19 +34,19 @@ PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/l
|
|||
|
||||
Override some args of optimizer:
|
||||
python speech_to_text.py \
|
||||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="./an4/train_manifest.json" \
|
||||
model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=2 \
|
||||
trainer.max_epochs=2 \
|
||||
model.optim.args.betas=[0.8,0.5] \
|
||||
model.optim.args.weight_decay=0.0001
|
||||
|
||||
Overide optimizer entirely
|
||||
Override optimizer entirely
|
||||
python speech_to_text.py \
|
||||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="./an4/train_manifest.json" \
|
||||
model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=2 \
|
||||
trainer.max_epochs=2 \
|
||||
model.optim.name=adamw \
|
||||
|
@ -62,16 +55,38 @@ Overide optimizer entirely
|
|||
+model.optim.args.betas=[0.8,0.5]\
|
||||
+model.optim.args.weight_decay=0.0005
|
||||
|
||||
# Fine-tune a model
|
||||
|
||||
For documentation on fine-tuning this model, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
|
||||
|
||||
# Pretrained Models
|
||||
|
||||
For documentation on existing pretrained models, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html
|
||||
|
||||
"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.models import EncDecCTCModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
@hydra_runner(config_path="conf", config_name="config")
|
||||
def main(cfg):
|
||||
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
|
||||
|
||||
trainer = pl.Trainer(**cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)
|
||||
|
||||
# Initialize the weights of the model from another model, if provided via config
|
||||
asr_model.maybe_init_from_pretrained_checkpoint(cfg)
|
||||
|
||||
trainer.fit(asr_model)
|
||||
|
||||
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
|
||||
|
|
|
@ -50,7 +50,19 @@ python speech_to_text_bpe.py \
|
|||
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
|
||||
exp_manager.wandb_logger_kwargs.project="<Name of project>"
|
||||
```
|
||||
|
||||
# Fine-tune a model
|
||||
|
||||
For documentation on fine-tuning this model, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
|
||||
|
||||
# Pretrained Models
|
||||
|
||||
For documentation on existing pretrained models, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html
|
||||
|
||||
"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
@ -63,12 +75,14 @@ from nemo.utils.exp_manager import exp_manager
|
|||
@hydra_runner(config_path="experimental/configs/", config_name="config_bpe")
|
||||
def main(cfg):
|
||||
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
|
||||
print(OmegaConf.to_yaml(cfg))
|
||||
|
||||
trainer = pl.Trainer(**cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
|
||||
asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)
|
||||
|
||||
# Initialize the weights of the model from another model, if provided via config
|
||||
asr_model.maybe_init_from_pretrained_checkpoint(cfg)
|
||||
|
||||
trainer.fit(asr_model)
|
||||
|
||||
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
|
||||
|
|
|
@ -12,32 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from nemo.collections.asr.models import EncDecRNNTModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
"""
|
||||
# Preparing the Tokenizer for the dataset
|
||||
Use the `process_asr_text_tokenizer.py` script under <NEMO_ROOT>/scripts/tokenizers/ in order to prepare the tokenizer.
|
||||
|
||||
```sh
|
||||
python <NEMO_ROOT>/scripts/tokenizers/process_asr_text_tokenizer.py \
|
||||
--manifest=<path to train manifest files, seperated by commas> \
|
||||
--data_root="<output directory>" \
|
||||
--vocab_size=<number of tokens in vocabulary> \
|
||||
--tokenizer=<"bpe" or "wpe"> \
|
||||
--log
|
||||
```
|
||||
|
||||
# Training the model
|
||||
|
||||
Basic run (on CPU for 50 epochs):
|
||||
python examples/asr/speech_to_text_rnnt.py \
|
||||
model.train_ds.manifest_filepath="<path to train dataset>" \
|
||||
model.validation_ds.manifest_filepath="<path to validation dataset>" \
|
||||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath="<path to manifest file>" \
|
||||
model.validation_ds.manifest_filepath="<path to manifest file>" \
|
||||
trainer.gpus=0 \
|
||||
trainer.max_epochs=50
|
||||
|
||||
|
@ -56,12 +38,11 @@ Override some args of optimizer:
|
|||
--config-name="config_rnnt" \
|
||||
model.train_ds.manifest_filepath="./an4/train_manifest.json" \
|
||||
model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=2 \
|
||||
trainer.precision=16 \
|
||||
trainer.max_epochs=2 \
|
||||
model.optim.args.params.betas=[0.8,0.5] \
|
||||
model.optim.args.params.weight_decay=0.0001
|
||||
model.optim.betas=[0.8,0.5] \
|
||||
model.optim.weight_decay=0.0001
|
||||
|
||||
Override optimizer entirely
|
||||
python speech_to_text_rnnt.py \
|
||||
|
@ -69,7 +50,6 @@ Override optimizer entirely
|
|||
--config-name="config_rnnt" \
|
||||
model.train_ds.manifest_filepath="./an4/train_manifest.json" \
|
||||
model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=2 \
|
||||
trainer.precision=16 \
|
||||
trainer.max_epochs=2 \
|
||||
|
@ -79,15 +59,33 @@ Override optimizer entirely
|
|||
+model.optim.args.betas=[0.8,0.5]\
|
||||
+model.optim.args.weight_decay=0.0005
|
||||
|
||||
# Fine-tune a model
|
||||
|
||||
For documentation on fine-tuning this model, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
|
||||
|
||||
"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.models import EncDecRNNTModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
@hydra_runner(config_path="experimental/contextnet_rnnt", config_name="config_rnnt")
|
||||
def main(cfg):
|
||||
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
|
||||
|
||||
trainer = pl.Trainer(**cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
asr_model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)
|
||||
|
||||
# Initialize the weights of the model from another model, if provided via config
|
||||
asr_model.maybe_init_from_pretrained_checkpoint(cfg)
|
||||
|
||||
trainer.fit(asr_model)
|
||||
|
||||
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
|
||||
|
|
|
@ -12,69 +12,72 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
# Preparing the Tokenizer for the dataset
|
||||
Use the `process_asr_text_tokenizer.py` script under <NEMO_ROOT>/scripts/tokenizers/ in order to prepare the tokenizer.
|
||||
|
||||
```sh
|
||||
python <NEMO_ROOT>/scripts/tokenizers/process_asr_text_tokenizer.py \
|
||||
--manifest=<path to train manifest files, seperated by commas>
|
||||
OR
|
||||
--data_file=<path to text data, seperated by commas> \
|
||||
--data_root="<output directory>" \
|
||||
--vocab_size=<number of tokens in vocabulary> \
|
||||
--tokenizer=<"spe" or "wpe"> \
|
||||
--no_lower_case \
|
||||
--spe_type=<"unigram", "bpe", "char" or "word"> \
|
||||
--spe_character_coverage=1.0 \
|
||||
--log
|
||||
```
|
||||
|
||||
# Training the model
|
||||
```sh
|
||||
python speech_to_text_rnnt_bpe.py \
|
||||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath=<path to train manifest> \
|
||||
model.validation_ds.manifest_filepath=<path to val/test manifest> \
|
||||
model.tokenizer.dir=<path to directory of tokenizer (not full path to the vocab file!)> \
|
||||
model.tokenizer.type=<either bpe or wpe> \
|
||||
trainer.gpus=-1 \
|
||||
trainer.accelerator="ddp" \
|
||||
trainer.max_epochs=100 \
|
||||
model.optim.name="adamw" \
|
||||
model.optim.lr=0.001 \
|
||||
model.optim.betas=[0.9,0.999] \
|
||||
model.optim.weight_decay=0.0001 \
|
||||
model.optim.sched.warmup_steps=2000
|
||||
exp_manager.create_wandb_logger=True \
|
||||
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
|
||||
exp_manager.wandb_logger_kwargs.project="<Name of project>"
|
||||
```
|
||||
|
||||
# Fine-tune a model
|
||||
|
||||
For documentation on fine-tuning this model, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
|
||||
|
||||
"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.models import EncDecRNNTBPEModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
|
||||
"""
|
||||
Basic run (on CPU for 50 epochs):
|
||||
python examples/asr/speech_to_text_rnnt_bpe.py \
|
||||
model.train_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_train.json" \
|
||||
model.validation_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_val.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=0 \
|
||||
trainer.max_epochs=50
|
||||
|
||||
|
||||
Add PyTorch Lightning Trainer arguments from CLI:
|
||||
python speech_to_text_rnnt_bpe.py \
|
||||
... \
|
||||
+trainer.fast_dev_run=true
|
||||
|
||||
Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)"
|
||||
PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)"
|
||||
|
||||
Override some args of optimizer:
|
||||
python speech_to_text_rnnt_bpe.py \
|
||||
--config-path="experimental/contextnet_rnnt" \
|
||||
--config-name="config_rnnt_bpe" \
|
||||
model.train_ds.manifest_filepath="./an4/train_manifest.json" \
|
||||
model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=2 \
|
||||
trainer.precision=16 \
|
||||
trainer.max_epochs=2 \
|
||||
model.optim.args.params.betas=[0.8,0.5] \
|
||||
model.optim.args.params.weight_decay=0.0001
|
||||
|
||||
Overide optimizer entirely
|
||||
python speech_to_text_rnnt_bpe.py \
|
||||
--config-path="experimental/contextnet_rnnt" \
|
||||
--config-name="config_rnnt_bpe" \
|
||||
model.train_ds.manifest_filepath="./an4/train_manifest.json" \
|
||||
model.validation_ds.manifest_filepath="./an4/test_manifest.json" \
|
||||
hydra.run.dir="." \
|
||||
trainer.gpus=2 \
|
||||
trainer.precision=16 \
|
||||
trainer.max_epochs=2 \
|
||||
model.optim.name=adamw \
|
||||
model.optim.lr=0.001 \
|
||||
~model.optim.args \
|
||||
+model.optim.args.betas=[0.8,0.5]\
|
||||
+model.optim.args.weight_decay=0.0005
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@hydra_runner(config_path="experimental/contextnet_rnnt", config_name="config_rnnt_bpe")
|
||||
def main(cfg):
|
||||
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
|
||||
|
||||
trainer = pl.Trainer(**cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
asr_model = EncDecRNNTBPEModel(cfg=cfg.model, trainer=trainer)
|
||||
|
||||
# Initialize the weights of the model from another model, if provided via config
|
||||
asr_model.maybe_init_from_pretrained_checkpoint(cfg)
|
||||
|
||||
trainer.fit(asr_model)
|
||||
|
||||
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
|
||||
|
|
|
@ -38,7 +38,7 @@ from argparse import ArgumentParser
|
|||
import torch
|
||||
|
||||
from nemo.collections.asr.models import EncDecClassificationModel
|
||||
from nemo.collections.asr.parts.vad_utils import get_vad_stream_status, prepare_manifest
|
||||
from nemo.collections.asr.parts.utils.vad_utils import get_vad_stream_status, prepare_manifest
|
||||
from nemo.utils import logging
|
||||
|
||||
try:
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script contains an example of how to train and test the NeMo SGD-QA Model.
|
||||
This script contains an example of how to train and test the NeMo SGD-QA Model (https://arxiv.org/abs/2105.08049).
|
||||
The SGD-QA model is a fast multi-pass schema-guided state-tracking model, that is trained on the Google schema-guided state tracking dataset (https://arxiv.org/abs/1909.05855).
|
||||
The model takes dialogue as input and outputs the dialogue state, which includes slot-value pairs.
|
||||
The model consists of two components: a neural natural language understanding model (NLU), and a rule-based state tracker.
|
||||
|
|
|
@ -111,6 +111,7 @@ model:
|
|||
hidden_act: relu
|
||||
mask_future: false
|
||||
pre_ln: false
|
||||
pre_ln_final_layer_norm: true
|
||||
|
||||
decoder:
|
||||
library: nemo
|
||||
|
@ -129,6 +130,7 @@ model:
|
|||
attn_layer_dropout: 0.1
|
||||
hidden_act: relu
|
||||
pre_ln: false
|
||||
pre_ln_final_layer_norm: true
|
||||
|
||||
head:
|
||||
num_layers: 1
|
||||
|
|
|
@ -30,8 +30,8 @@ def main(cfg: DictConfig) -> None:
|
|||
logging.info(f'\nConfig Params:\n{cfg.pretty()}')
|
||||
trainer = pl.Trainer(plugins=[NLPDDPPlugin()], **cfg.trainer)
|
||||
exp_manager(trainer, cfg.get("exp_manager", None))
|
||||
|
||||
model = TextClassificationModel.restore_from(cfg.model.nemo_path, trainer=trainer)
|
||||
# TODO: can we drop strict=False
|
||||
model = TextClassificationModel.restore_from(cfg.model.nemo_path, trainer=trainer, strict=False)
|
||||
model.setup_test_data(test_data_config=cfg.model.test_ds)
|
||||
|
||||
trainer.test(model=model, ckpt_path=None)
|
||||
|
|
|
@ -110,7 +110,8 @@ def main(cfg: DictConfig) -> None:
|
|||
model = TokenClassificationModel(cfg.model, trainer=trainer)
|
||||
else:
|
||||
if os.path.exists(cfg.pretrained_model):
|
||||
model = TokenClassificationModel.restore_from(cfg.pretrained_model, trainer=trainer)
|
||||
# TODO: can we drop strict=False?
|
||||
model = TokenClassificationModel.restore_from(cfg.pretrained_model, trainer=trainer, strict=False)
|
||||
elif cfg.pretrained_model in TokenClassificationModel.get_available_model_names():
|
||||
model = TokenClassificationModel.from_pretrained(cfg.pretrained_model)
|
||||
else:
|
||||
|
|
|
@ -42,7 +42,7 @@ model:
|
|||
num_workers: 8
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${model.n_mel_channels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -47,7 +47,7 @@ model:
|
|||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${n_mels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -52,7 +52,7 @@ model:
|
|||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${n_mels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -9,7 +9,7 @@ defaults:
|
|||
|
||||
model:
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
frame_splicing: 1
|
||||
nfilt: 80
|
||||
|
|
|
@ -38,7 +38,7 @@ model:
|
|||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
frame_splicing: 1
|
||||
nfilt: ${n_mels}
|
||||
|
|
|
@ -42,7 +42,7 @@ model:
|
|||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${n_mels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -55,7 +55,7 @@ model:
|
|||
num_workers: 8
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${n_mels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -39,7 +39,7 @@ model:
|
|||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${n_mels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -38,7 +38,7 @@ model:
|
|||
num_workers: 4
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.parts.features.FilterbankFeatures
|
||||
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
|
||||
dither: 0.0
|
||||
nfilt: ${n_mels}
|
||||
frame_splicing: 1
|
||||
|
|
|
@ -26,7 +26,7 @@ import torch
|
|||
|
||||
from nemo.collections.asr.metrics.wer import word_error_rate
|
||||
from nemo.collections.asr.models import EncDecCTCModel
|
||||
from nemo.collections.asr.parts import parsers
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder
|
||||
from nemo.utils import logging
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import braceexpand
|
|||
import torch
|
||||
import webdataset as wd
|
||||
|
||||
from nemo.collections.asr.parts import collections
|
||||
from nemo.collections.common.parts.preprocessing import collections
|
||||
from nemo.core.classes import Dataset, IterableDataset
|
||||
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType, RegressionValuesType
|
||||
from nemo.utils import logging
|
||||
|
|
|
@ -25,8 +25,8 @@ from scipy.stats import betabinom
|
|||
from torch.nn import functional as F
|
||||
|
||||
from nemo.collections.asr.data import vocabs
|
||||
from nemo.collections.asr.parts import collections, parsers
|
||||
from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
||||
from nemo.collections.common.parts.preprocessing import collections, parsers
|
||||
from nemo.core.classes import Dataset, IterableDataset
|
||||
from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types.elements import ProbsType
|
||||
|
|
|
@ -16,11 +16,10 @@ import math
|
|||
from collections.abc import Iterator
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from nemo.collections.asr.parts import parsers
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.utils.decorators import experimental
|
||||
|
||||
try:
|
||||
|
@ -128,6 +127,7 @@ class AudioToCharDALIDataset(Iterator):
|
|||
world_size: int = 1,
|
||||
preprocessor_cfg: DictConfig = None,
|
||||
):
|
||||
self.drop_last = drop_last # used by lr_scheduler
|
||||
if not HAVE_DALI:
|
||||
raise ModuleNotFoundError(
|
||||
f"{self} requires NVIDIA DALI to be installed. "
|
||||
|
@ -173,18 +173,18 @@ class AudioToCharDALIDataset(Iterator):
|
|||
|
||||
has_preprocessor = preprocessor_cfg is not None
|
||||
if has_preprocessor:
|
||||
if preprocessor_cfg.cls == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor":
|
||||
if preprocessor_cfg._target_ == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor":
|
||||
feature_type = "mel_spectrogram"
|
||||
elif preprocessor_cfg.cls == "nemo.collections.asr.modules.AudioToMFCCPreprocessor":
|
||||
elif preprocessor_cfg._target_ == "nemo.collections.asr.modules.AudioToMFCCPreprocessor":
|
||||
feature_type = "mfcc"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self} received an unexpected preprocessor configuration: {preprocessor_cfg.cls}."
|
||||
f"{self} received an unexpected preprocessor configuration: {preprocessor_cfg._target_}."
|
||||
f" Supported preprocessors are: AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor"
|
||||
)
|
||||
|
||||
# Default values taken from AudioToMelSpectrogramPreprocessor
|
||||
params = preprocessor_cfg.params
|
||||
params = preprocessor_cfg
|
||||
self.dither = params['dither'] if 'dither' in params else 0.0
|
||||
self.preemph = params['preemph'] if 'preemph' in params else 0.97
|
||||
self.window_size_sec = params['window_size'] if 'window_size' in params else 0.02
|
||||
|
@ -291,7 +291,7 @@ class AudioToCharDALIDataset(Iterator):
|
|||
random_shuffle=shuffle,
|
||||
shard_id=self.shard_id,
|
||||
num_shards=self.num_shards,
|
||||
pad_last_batch=True,
|
||||
pad_last_batch=False,
|
||||
)
|
||||
|
||||
transcript_len = dali.fn.shapes(dali.fn.reshape(transcript, shape=[-1]))
|
||||
|
@ -310,8 +310,8 @@ class AudioToCharDALIDataset(Iterator):
|
|||
|
||||
if not has_preprocessor:
|
||||
# No preprocessing, the output is the audio signal
|
||||
audio = dali.fn.pad(audio)
|
||||
audio_len = dali.fn.shapes(dali.fn.reshape(audio, shape=[-1]))
|
||||
audio = dali.fn.pad(audio)
|
||||
self.pipe.set_outputs(audio, audio_len, transcript, transcript_len)
|
||||
else:
|
||||
# Additive gaussian noise (dither)
|
||||
|
@ -354,20 +354,11 @@ class AudioToCharDALIDataset(Iterator):
|
|||
spec = dali.fn.normalize(spec, axes=self.normalization_axes)
|
||||
|
||||
# Extracting the length of the spectrogram
|
||||
shape_start = dali.types.Constant(np.array([1], dtype=np.float32), device='cpu')
|
||||
shape_len = dali.types.Constant(np.array([1], dtype=np.float32), device='cpu')
|
||||
spec_len = dali.fn.slice(
|
||||
dali.fn.shapes(spec),
|
||||
shape_start,
|
||||
shape_len,
|
||||
normalized_anchor=False,
|
||||
normalized_shape=False,
|
||||
axes=(0,),
|
||||
)
|
||||
spec_len = dali.fn.slice(dali.fn.shapes(spec), 1, 1, axes=(0,))
|
||||
|
||||
# Pads feature dimension to be a multiple of `pad_to` and the temporal dimension to be as big as the largest sample (shape -1)
|
||||
spec = dali.fn.pad(spec, fill_value=self.pad_value, axes=(0, 1), align=(self.pad_to, 1), shape=(1, -1))
|
||||
self.pipe.set_outputs(spec, spec_len, transcript, transcript_len)
|
||||
self.pipe.set_outputs(spec, spec_len, transcript, transcript_len)
|
||||
# Building DALI pipeline
|
||||
self.pipe.build()
|
||||
|
||||
|
@ -415,9 +406,15 @@ class AudioToCharDALIDataset(Iterator):
|
|||
def __next__(self):
|
||||
outputs = self._iter.next()
|
||||
assert len(outputs) == 1
|
||||
out = outputs[0]
|
||||
text_raw_len = out['transcript_raw_len'].numpy()
|
||||
text_raw = out['transcript_raw'].numpy()
|
||||
dali_out = outputs[0]
|
||||
text_raw_len = dali_out['transcript_raw_len'].numpy()
|
||||
text_raw = dali_out['transcript_raw'].numpy()
|
||||
|
||||
out = {}
|
||||
out_names = ['processed_signal', 'processed_signal_len', 'audio', 'audio_len']
|
||||
for out_name in out_names:
|
||||
if out_name in dali_out:
|
||||
out[out_name] = dali_out[out_name].detach().clone()
|
||||
|
||||
text_tokens = []
|
||||
text_tokens_len = []
|
||||
|
|
|
@ -15,9 +15,50 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from omegaconf import DictConfig, open_dict
|
||||
|
||||
from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
|
||||
from nemo.utils import logging
|
||||
|
||||
|
||||
def inject_dataloader_value_from_model_config(model_cfg: dict, dataloader_cfg: dict, key: str):
|
||||
"""
|
||||
Extracts the label set provided at the top level of the model, and propagates it to the dataloader
|
||||
config.
|
||||
|
||||
Args:
|
||||
model_cfg: A DictConfig representing the model's config.
|
||||
dataloader_cfg: A DictConfig representing the individual data loader
|
||||
key: A str value representing a key in the model_cfg whose value will be propagated to the
|
||||
dataloader config.
|
||||
"""
|
||||
if key not in model_cfg:
|
||||
logging.info(
|
||||
f"Model level config does not container `{key}`, please explicitly provide `{key}` to the dataloaders."
|
||||
)
|
||||
return
|
||||
|
||||
# If key exists in the data loader config (either set explicitly or as a placeholder (via None))
|
||||
if key in dataloader_cfg:
|
||||
# Dataloader `labels` is provided and is non-null
|
||||
if dataloader_cfg[key] is not None and model_cfg[key] != dataloader_cfg[key]:
|
||||
# Model level `labels` dont match Dataloader level `labels`
|
||||
logging.warning(
|
||||
f'`{key}` is explicitly provided to the data loader, and is different from '
|
||||
f'the `{key}` provided at the model level config.\n'
|
||||
f'If this is incorrect, please set the dataloader\'s `{key}` to None.'
|
||||
)
|
||||
|
||||
else:
|
||||
# Dataloader `key` is None or values match
|
||||
# Propagate from model level `key` (even if they match)
|
||||
with open_dict(dataloader_cfg):
|
||||
dataloader_cfg[key] = model_cfg[key]
|
||||
|
||||
else:
|
||||
# If key key doesnt even exist in dataloader_cfg, inject it explicitly
|
||||
with open_dict(dataloader_cfg):
|
||||
dataloader_cfg[key] = model_cfg[key]
|
||||
|
||||
|
||||
def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) -> audio_to_text.AudioToCharDataset:
|
||||
|
|
|
@ -15,7 +15,7 @@ from typing import Dict, List, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from nemo.collections.asr.parts import collections
|
||||
from nemo.collections.common.parts.preprocessing import collections
|
||||
from nemo.core.classes import Dataset
|
||||
from nemo.core.neural_types import AcousticEncodedRepresentation, LabelsType, LengthsType, NeuralType
|
||||
from nemo.utils import logging
|
||||
|
|
|
@ -21,18 +21,8 @@ from builtins import str as unicode
|
|||
from typing import List
|
||||
|
||||
import nltk
|
||||
from nltk.corpus import cmudict
|
||||
|
||||
from nemo.collections.asr.parts import parsers
|
||||
|
||||
try:
|
||||
nltk.data.find('taggers/averaged_perceptron_tagger.zip')
|
||||
except LookupError:
|
||||
nltk.download('averaged_perceptron_tagger', quiet=True)
|
||||
try:
|
||||
nltk.data.find('corpora/cmudict.zip')
|
||||
except LookupError:
|
||||
nltk.download('cmudict', quiet=True)
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
|
||||
try:
|
||||
import g2p_en # noqa
|
||||
|
@ -70,6 +60,16 @@ class G2p:
|
|||
text_preprocessing_func=_text_preprocessing,
|
||||
word_tokenize_func=_word_tokenize,
|
||||
):
|
||||
# Download NLTK datasets if this class is to be instantiated
|
||||
try:
|
||||
nltk.data.find('taggers/averaged_perceptron_tagger.zip')
|
||||
except LookupError:
|
||||
nltk.download('averaged_perceptron_tagger', quiet=True)
|
||||
try:
|
||||
nltk.data.find('corpora/cmudict.zip')
|
||||
except LookupError:
|
||||
nltk.download('cmudict', quiet=True)
|
||||
|
||||
self.homograph2features = _g2p.homograph2features
|
||||
self.g2p_dict = self._construct_grapheme2phoneme_dict(phoneme_dict_path)
|
||||
self.use_seq2seq_for_oov = use_seq2seq_for_oov
|
||||
|
@ -81,6 +81,8 @@ class G2p:
|
|||
@staticmethod
|
||||
def _construct_grapheme2phoneme_dict(phoneme_dict_path=None, encoding='latin-1'):
|
||||
if phoneme_dict_path is None:
|
||||
from nltk.corpus import cmudict
|
||||
|
||||
return cmudict.dict()
|
||||
|
||||
_alt_re = re.compile(r'\([0-9]+\)')
|
||||
|
|
|
@ -20,9 +20,9 @@ import editdistance
|
|||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
|
||||
from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
|
||||
from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
|
||||
from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
|
||||
from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
|
||||
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
|
||||
from nemo.utils import logging
|
||||
|
||||
__all__ = ['RNNTDecoding', 'RNNTWER']
|
||||
|
|
|
@ -13,16 +13,15 @@
|
|||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import editdistance
|
||||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
from nemo.collections.asr.metrics.rnnt_wer import AbstractRNNTDecoding
|
||||
from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
|
||||
from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
|
||||
from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
|
||||
from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
|
||||
from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
|
||||
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
|
||||
from nemo.utils import logging
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import editdistance
|
|||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
from nemo.collections.asr.parts.rnnt_utils import Hypothesis
|
||||
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
|
||||
from nemo.utils import logging
|
||||
|
||||
__all__ = ['word_error_rate', 'WER']
|
||||
|
|
|
@ -18,7 +18,7 @@ import editdistance
|
|||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
from nemo.collections.asr.parts.rnnt_utils import Hypothesis
|
||||
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
|
||||
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
|
||||
from nemo.utils import logging
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ from abc import abstractmethod
|
|||
from math import ceil
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from pytorch_lightning import Trainer
|
||||
|
@ -28,8 +27,8 @@ from pytorch_lightning.metrics.regression import MeanAbsoluteError, MeanSquaredE
|
|||
|
||||
from nemo.collections.asr.data import audio_to_label_dataset
|
||||
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
|
||||
from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.collections.common.losses import CrossEntropyLoss, MSELoss
|
||||
from nemo.collections.common.metrics import TopKClassificationAccuracy
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
|
@ -478,6 +477,7 @@ class EncDecClassificationModel(_EncDecBaseModel):
|
|||
|
||||
self._accuracy(logits=logits, labels=labels)
|
||||
topk_scores = self._accuracy.compute()
|
||||
self._accuracy.reset()
|
||||
|
||||
for top_k, score in zip(self._accuracy.top_k, topk_scores):
|
||||
self.log('training_batch_accuracy_top@{}'.format(top_k), score)
|
||||
|
@ -520,6 +520,7 @@ class EncDecClassificationModel(_EncDecBaseModel):
|
|||
self._accuracy.correct_counts_k = correct_counts
|
||||
self._accuracy.total_counts_k = total_counts
|
||||
topk_scores = self._accuracy.compute()
|
||||
self._accuracy.reset()
|
||||
|
||||
tensorboard_log = {'val_loss': val_loss_mean}
|
||||
for top_k, score in zip(self._accuracy.top_k, topk_scores):
|
||||
|
@ -535,6 +536,7 @@ class EncDecClassificationModel(_EncDecBaseModel):
|
|||
self._accuracy.correct_counts_k = correct_counts
|
||||
self._accuracy.total_counts_k = total_counts
|
||||
topk_scores = self._accuracy.compute()
|
||||
self._accuracy.reset()
|
||||
|
||||
tensorboard_log = {'test_loss': test_loss_mean}
|
||||
for top_k, score in zip(self._accuracy.top_k, topk_scores):
|
||||
|
@ -699,7 +701,9 @@ class EncDecRegressionModel(_EncDecBaseModel):
|
|||
def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
|
||||
val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
|
||||
val_mse = self._mse.compute()
|
||||
self._mse.reset()
|
||||
val_mae = self._mae.compute()
|
||||
self._mae.reset()
|
||||
|
||||
tensorboard_logs = {'val_loss': val_loss_mean, 'val_mse': val_mse, 'val_mae': val_mae}
|
||||
|
||||
|
@ -708,7 +712,9 @@ class EncDecRegressionModel(_EncDecBaseModel):
|
|||
def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
|
||||
test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
|
||||
test_mse = self._mse.compute()
|
||||
self._mse.reset()
|
||||
test_mae = self._mae.compute()
|
||||
self._mae.reset()
|
||||
|
||||
tensorboard_logs = {'test_loss': test_loss_mean, 'test_mse': test_mse, 'test_mae': test_mae}
|
||||
|
||||
|
|
|
@ -29,9 +29,9 @@ from tqdm import tqdm
|
|||
|
||||
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
|
||||
from nemo.collections.asr.models.label_models import ExtractSpeakerEmbeddingsModel
|
||||
from nemo.collections.asr.parts.mixins import DiarizationMixin
|
||||
from nemo.collections.asr.parts.speaker_utils import audio_rttm_map, perform_diarization, write_rttm2manifest
|
||||
from nemo.collections.asr.parts.vad_utils import (
|
||||
from nemo.collections.asr.parts.mixins.mixins import DiarizationMixin
|
||||
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, perform_diarization, write_rttm2manifest
|
||||
from nemo.collections.asr.parts.utils.vad_utils import (
|
||||
generate_overlap_vad_seq,
|
||||
generate_vad_segment_table,
|
||||
get_vad_stream_status,
|
||||
|
|
|
@ -24,7 +24,7 @@ from nemo.collections.asr.losses.ctc import CTCLoss
|
|||
from nemo.collections.asr.metrics.wer_bpe import WERBPE
|
||||
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
|
||||
from nemo.collections.asr.parts.mixins import ASRBPEMixin
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.core.classes.common import PretrainedModelInfo
|
||||
from nemo.utils import logging, model_utils
|
||||
|
||||
|
@ -66,6 +66,14 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
|
|||
|
||||
results.append(model)
|
||||
|
||||
model = PretrainedModelInfo(
|
||||
pretrained_model_name="stt_es_citrinet_512",
|
||||
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_citrinet_512",
|
||||
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_citrinet_512/versions/1.0.0/files/stt_es_citrinet_512.nemo",
|
||||
)
|
||||
|
||||
results.append(model)
|
||||
|
||||
model = PretrainedModelInfo(
|
||||
pretrained_model_name="stt_en_conformer_ctc_small",
|
||||
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_small",
|
||||
|
|
|
@ -29,7 +29,7 @@ from nemo.collections.asr.losses.ctc import CTCLoss
|
|||
from nemo.collections.asr.metrics.wer import WER
|
||||
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
|
||||
from nemo.collections.asr.parts.mixins import ASRModuleMixin
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
|
||||
from nemo.utils import logging
|
||||
|
@ -330,6 +330,12 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
|
|||
self._cfg.decoder = new_decoder_config
|
||||
OmegaConf.set_struct(self._cfg.decoder, True)
|
||||
|
||||
ds_keys = ['train_ds', 'validation_ds', 'test_ds']
|
||||
for key in ds_keys:
|
||||
if key in self.cfg:
|
||||
with open_dict(self.cfg[key]):
|
||||
self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary)
|
||||
|
||||
logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.")
|
||||
|
||||
def _setup_dataloader_from_config(self, config: Optional[Dict]):
|
||||
|
@ -338,6 +344,10 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
|
|||
else:
|
||||
augmentor = None
|
||||
|
||||
# Automatically inject args from model config to dataloader config
|
||||
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
|
||||
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')
|
||||
|
||||
shuffle = config['shuffle']
|
||||
device = 'gpu' if torch.cuda.is_available() else 'cpu'
|
||||
if config.get('use_dali', False):
|
||||
|
|
|
@ -26,8 +26,8 @@ from pytorch_lightning import Trainer
|
|||
from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset
|
||||
from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
|
||||
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
|
||||
from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.collections.common.losses import CrossEntropyLoss as CELoss
|
||||
from nemo.collections.common.metrics import TopKClassificationAccuracy
|
||||
from nemo.core.classes import ModelPT
|
||||
|
@ -205,6 +205,7 @@ class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel):
|
|||
|
||||
self._accuracy(logits=logits, labels=labels)
|
||||
top_k = self._accuracy.compute()
|
||||
self._accuracy.reset()
|
||||
for i, top_i in enumerate(top_k):
|
||||
self.log(f'training_batch_accuracy_top@{i}', top_i)
|
||||
|
||||
|
@ -232,6 +233,7 @@ class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel):
|
|||
self._accuracy.correct_counts_k = correct_counts
|
||||
self._accuracy.total_counts_k = total_counts
|
||||
topk_scores = self._accuracy.compute()
|
||||
self._accuracy.reset()
|
||||
|
||||
logging.info("val_loss: {:.3f}".format(val_loss_mean))
|
||||
self.log('val_loss', val_loss_mean)
|
||||
|
@ -265,6 +267,7 @@ class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel):
|
|||
self._accuracy.correct_counts_k = correct_counts
|
||||
self._accuracy.total_counts_k = total_counts
|
||||
topk_scores = self._accuracy.compute()
|
||||
self._accuracy.reset()
|
||||
|
||||
logging.info("test_loss: {:.3f}".format(test_loss_mean))
|
||||
self.log('test_loss', test_loss_mean)
|
||||
|
|
|
@ -25,7 +25,7 @@ from nemo.collections.asr.losses.rnnt import RNNTLoss
|
|||
from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding
|
||||
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
|
||||
from nemo.collections.asr.parts.mixins import ASRBPEMixin
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.core.classes.common import PretrainedModelInfo
|
||||
from nemo.utils import logging, model_utils
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss
|
|||
from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding
|
||||
from nemo.collections.asr.models.asr_model import ASRModel
|
||||
from nemo.collections.asr.parts.mixins import ASRModuleMixin
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType
|
||||
from nemo.utils import logging
|
||||
|
@ -340,6 +340,12 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
with open_dict(self.cfg.decoding):
|
||||
self.cfg.decoding = decoding_cfg
|
||||
|
||||
ds_keys = ['train_ds', 'validation_ds', 'test_ds']
|
||||
for key in ds_keys:
|
||||
if key in self.cfg:
|
||||
with open_dict(self.cfg[key]):
|
||||
self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary)
|
||||
|
||||
logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.")
|
||||
|
||||
def change_decoding_strategy(self, decoding_cfg: DictConfig):
|
||||
|
@ -384,6 +390,10 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin):
|
|||
else:
|
||||
augmentor = None
|
||||
|
||||
# Automatically inject args from model config to dataloader config
|
||||
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
|
||||
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')
|
||||
|
||||
shuffle = config['shuffle']
|
||||
device = 'gpu' if torch.cuda.is_available() else 'cpu'
|
||||
if config.get('use_dali', False):
|
||||
|
|
|
@ -31,8 +31,8 @@ from nemo.collections.asr.data import audio_to_text_dataset
|
|||
from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss
|
||||
from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecEncoderModelConfig
|
||||
from nemo.collections.asr.modules.wav2vec_modules import GumbelVectorQuantizer, compute_mask_indices
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.wav2vec import ConvFeatureEncoder, GradMultiply, Wav2VecTransformerEncoder
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.submodules.wav2vec import ConvFeatureEncoder, GradMultiply, Wav2VecTransformerEncoder
|
||||
from nemo.core import ModelPT
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LossType, MaskType, NeuralType
|
||||
|
|
|
@ -20,8 +20,8 @@ from typing import Any, Optional
|
|||
import torch
|
||||
from packaging import version
|
||||
|
||||
from nemo.collections.asr.parts.features import FilterbankFeatures
|
||||
from nemo.collections.asr.parts.spectr_augment import SpecAugment, SpecCutout
|
||||
from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
|
||||
from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout
|
||||
from nemo.core.classes import NeuralModule, typecheck
|
||||
from nemo.core.neural_types import (
|
||||
AudioSignal,
|
||||
|
@ -425,13 +425,6 @@ class SpectrogramAugmentation(NeuralModule):
|
|||
Defaults to 25.
|
||||
"""
|
||||
|
||||
def save_to(self, save_path: str):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def restore_from(cls, restore_path: str):
|
||||
pass
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
"""Returns definitions of module input types
|
||||
|
@ -462,7 +455,7 @@ class SpectrogramAugmentation(NeuralModule):
|
|||
self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,)
|
||||
# self.spec_cutout.to(self._device)
|
||||
else:
|
||||
self.spec_cutout = lambda x: x
|
||||
self.spec_cutout = lambda input_spec: input_spec
|
||||
|
||||
if freq_masks + time_masks > 0:
|
||||
self.spec_augment = SpecAugment(
|
||||
|
@ -474,12 +467,12 @@ class SpectrogramAugmentation(NeuralModule):
|
|||
mask_value=mask_value,
|
||||
)
|
||||
else:
|
||||
self.spec_augment = lambda x: x
|
||||
self.spec_augment = lambda input_spec: input_spec
|
||||
|
||||
@typecheck()
|
||||
def forward(self, input_spec):
|
||||
augmented_spec = self.spec_cutout(input_spec)
|
||||
augmented_spec = self.spec_augment(augmented_spec)
|
||||
augmented_spec = self.spec_cutout(input_spec=input_spec)
|
||||
augmented_spec = self.spec_augment(input_spec=augmented_spec)
|
||||
return augmented_spec
|
||||
|
||||
|
||||
|
|
|
@ -18,9 +18,9 @@ from collections import OrderedDict
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nemo.collections.asr.parts.conformer_modules import ConformerLayer
|
||||
from nemo.collections.asr.parts.multi_head_attention import PositionalEncoding, RelPositionalEncoding
|
||||
from nemo.collections.asr.parts.subsampling import ConvSubsampling
|
||||
from nemo.collections.asr.parts.submodules.conformer_modules import ConformerLayer
|
||||
from nemo.collections.asr.parts.submodules.multi_head_attention import PositionalEncoding, RelPositionalEncoding
|
||||
from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling
|
||||
from nemo.core.classes.common import typecheck
|
||||
from nemo.core.classes.exportable import Exportable
|
||||
from nemo.core.classes.module import NeuralModule
|
||||
|
|
|
@ -20,7 +20,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from omegaconf import MISSING, ListConfig, OmegaConf
|
||||
|
||||
from nemo.collections.asr.parts.jasper import (
|
||||
from nemo.collections.asr.parts.submodules.jasper import (
|
||||
JasperBlock,
|
||||
MaskedConv1d,
|
||||
StatsPoolLayer,
|
||||
|
@ -127,6 +127,11 @@ class ConvASREncoder(NeuralModule, Exportable):
|
|||
jasper = OmegaConf.to_container(jasper)
|
||||
|
||||
activation = jasper_activations[activation]()
|
||||
|
||||
# If the activation can be executed in place, do so.
|
||||
if hasattr(activation, 'inplace'):
|
||||
activation.inplace = True
|
||||
|
||||
feat_in = feat_in * frame_splicing
|
||||
|
||||
self._feat_in = feat_in
|
||||
|
|
|
@ -31,7 +31,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
import torch
|
||||
|
||||
from nemo.collections.asr.modules import rnnt_abstract
|
||||
from nemo.collections.asr.parts import rnnt_utils
|
||||
from nemo.collections.asr.parts.utils import rnnt_utils
|
||||
from nemo.collections.common.parts import rnn
|
||||
from nemo.core.classes import typecheck
|
||||
from nemo.core.neural_types import (
|
||||
|
|
|
@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from nemo.collections.asr.parts.rnnt_utils import Hypothesis
|
||||
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
|
||||
from nemo.core import NeuralModule
|
||||
|
||||
|
||||
|
|
|
@ -32,401 +32,8 @@
|
|||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
# This file contains code artifacts adapted from https://github.com/ryanleary/patter
|
||||
import math
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from librosa.util import tiny
|
||||
from torch.autograd import Variable
|
||||
from torch_stft import STFT
|
||||
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
from nemo.collections.common.parts.patch_utils import stft_patch
|
||||
from nemo.utils import logging
|
||||
|
||||
CONSTANT = 1e-5
|
||||
|
||||
|
||||
def normalize_batch(x, seq_len, normalize_type):
|
||||
if normalize_type == "per_feature":
|
||||
x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
||||
x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
||||
for i in range(x.shape[0]):
|
||||
if x[i, :, : seq_len[i]].shape[1] == 1:
|
||||
raise ValueError(
|
||||
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
|
||||
"in torch.std() returning nan"
|
||||
)
|
||||
x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
|
||||
x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
|
||||
# make sure x_std is not zero
|
||||
x_std += CONSTANT
|
||||
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
|
||||
elif normalize_type == "all_features":
|
||||
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
||||
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
||||
for i in range(x.shape[0]):
|
||||
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
|
||||
x_std[i] = x[i, :, : seq_len[i].item()].std()
|
||||
# make sure x_std is not zero
|
||||
x_std += CONSTANT
|
||||
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
|
||||
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
|
||||
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
|
||||
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
|
||||
return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def splice_frames(x, frame_splicing):
|
||||
""" Stacks frames together across feature dim
|
||||
|
||||
input is batch_size, feature_dim, num_frames
|
||||
output is batch_size, feature_dim*frame_splicing, num_frames
|
||||
|
||||
"""
|
||||
seq = [x]
|
||||
for n in range(1, frame_splicing):
|
||||
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
|
||||
class WaveformFeaturizer(object):
|
||||
def __init__(self, sample_rate=16000, int_values=False, augmentor=None):
|
||||
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False, orig_sr=None):
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
orig_sr=orig_sr,
|
||||
)
|
||||
return self.process_segment(audio)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment.samples, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
|
||||
|
||||
|
||||
class FeaturizerFactory(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_cfg, perturbation_configs=None):
|
||||
return WaveformFeaturizer.from_config(input_cfg, perturbation_configs=perturbation_configs)
|
||||
|
||||
|
||||
# Create helper class to patch forward func for use with AMP
|
||||
class STFTPatch(STFT):
|
||||
def forward(self, input_data):
|
||||
return super().transform(input_data)[0]
|
||||
|
||||
|
||||
# Create helper class for STFT that yields num_frames = num_samples // hop_length
|
||||
class STFTExactPad(STFTPatch):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
|
||||
def __init__(self, *params, **kw_params):
|
||||
super().__init__(*params, **kw_params)
|
||||
self.pad_amount = (self.filter_length - self.hop_length) // 2
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
|
||||
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
recombine_magnitude_phase,
|
||||
Variable(self.inverse_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
if self.window is not None:
|
||||
window_sum = librosa.filters.window_sumsquare(
|
||||
self.window,
|
||||
magnitude.size(-1),
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
n_fft=self.filter_length,
|
||||
dtype=np.float32,
|
||||
)
|
||||
# remove modulation effects
|
||||
approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0])
|
||||
window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False).to(
|
||||
magnitude.device
|
||||
)
|
||||
inverse_transform[..., approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= self.filter_length / self.hop_length
|
||||
|
||||
inverse_transform = inverse_transform[..., self.pad_amount :]
|
||||
inverse_transform = inverse_transform[..., : -self.pad_amount :]
|
||||
inverse_transform = inverse_transform.squeeze(1)
|
||||
|
||||
return inverse_transform
|
||||
|
||||
|
||||
class FilterbankFeatures(nn.Module):
|
||||
"""Featurizer that converts wavs to Mel Spectrograms.
|
||||
See AudioToMelSpectrogramPreprocessor for args.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=16000,
|
||||
n_window_size=320,
|
||||
n_window_stride=160,
|
||||
window="hann",
|
||||
normalize="per_feature",
|
||||
n_fft=None,
|
||||
preemph=0.97,
|
||||
nfilt=64,
|
||||
lowfreq=0,
|
||||
highfreq=None,
|
||||
log=True,
|
||||
log_zero_guard_type="add",
|
||||
log_zero_guard_value=2 ** -24,
|
||||
dither=CONSTANT,
|
||||
pad_to=16,
|
||||
max_duration=16.7,
|
||||
frame_splicing=1,
|
||||
exact_pad=False,
|
||||
stft_exact_pad=False, # TODO: Remove this in 1.1.0
|
||||
stft_conv=False, # TODO: Remove this in 1.1.0
|
||||
pad_value=0,
|
||||
mag_power=2.0,
|
||||
use_grads=False,
|
||||
):
|
||||
super().__init__()
|
||||
if stft_conv or stft_exact_pad:
|
||||
logging.warning(
|
||||
"Using torch_stft is deprecated and will be removed in 1.1.0. Please set stft_conv and stft_exact_pad "
|
||||
"to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
|
||||
"as needed."
|
||||
)
|
||||
if (exact_pad or stft_exact_pad) and n_window_stride % 2 == 1:
|
||||
raise NotImplementedError(
|
||||
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
|
||||
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
|
||||
)
|
||||
self.log_zero_guard_value = log_zero_guard_value
|
||||
if (
|
||||
n_window_size is None
|
||||
or n_window_stride is None
|
||||
or not isinstance(n_window_size, int)
|
||||
or not isinstance(n_window_stride, int)
|
||||
or n_window_size <= 0
|
||||
or n_window_stride <= 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"{self} got an invalid value for either n_window_size or "
|
||||
f"n_window_stride. Both must be positive ints."
|
||||
)
|
||||
logging.info(f"PADDING: {pad_to}")
|
||||
|
||||
self.win_length = n_window_size
|
||||
self.hop_length = n_window_stride
|
||||
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
|
||||
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
|
||||
self.stft_exact_pad = stft_exact_pad
|
||||
self.stft_conv = stft_conv
|
||||
|
||||
if stft_conv:
|
||||
logging.info("STFT using conv")
|
||||
if stft_exact_pad:
|
||||
logging.info("STFT using exact pad")
|
||||
self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window)
|
||||
else:
|
||||
self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window)
|
||||
else:
|
||||
logging.info("STFT using torch")
|
||||
if exact_pad:
|
||||
logging.info("STFT using exact pad")
|
||||
torch_windows = {
|
||||
'hann': torch.hann_window,
|
||||
'hamming': torch.hamming_window,
|
||||
'blackman': torch.blackman_window,
|
||||
'bartlett': torch.bartlett_window,
|
||||
'none': None,
|
||||
}
|
||||
window_fn = torch_windows.get(window, None)
|
||||
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
|
||||
self.register_buffer("window", window_tensor)
|
||||
self.stft = lambda x: stft_patch(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
center=False if exact_pad else True,
|
||||
window=self.window.to(dtype=torch.float),
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
self.normalize = normalize
|
||||
self.log = log
|
||||
self.dither = dither
|
||||
self.frame_splicing = frame_splicing
|
||||
self.nfilt = nfilt
|
||||
self.preemph = preemph
|
||||
self.pad_to = pad_to
|
||||
highfreq = highfreq or sample_rate / 2
|
||||
|
||||
filterbanks = torch.tensor(
|
||||
librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float
|
||||
).unsqueeze(0)
|
||||
self.register_buffer("fb", filterbanks)
|
||||
|
||||
# Calculate maximum sequence length
|
||||
max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
|
||||
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
|
||||
self.max_length = max_length + max_pad
|
||||
self.pad_value = pad_value
|
||||
self.mag_power = mag_power
|
||||
|
||||
# We want to avoid taking the log of zero
|
||||
# There are two options: either adding or clamping to a small value
|
||||
if log_zero_guard_type not in ["add", "clamp"]:
|
||||
raise ValueError(
|
||||
f"{self} received {log_zero_guard_type} for the "
|
||||
f"log_zero_guard_type parameter. It must be either 'add' or "
|
||||
f"'clamp'."
|
||||
)
|
||||
|
||||
self.use_grads = use_grads
|
||||
if not use_grads:
|
||||
self.forward = torch.no_grad()(self.forward)
|
||||
|
||||
# log_zero_guard_value is the the small we want to use, we support
|
||||
# an actual number, or "tiny", or "eps"
|
||||
self.log_zero_guard_type = log_zero_guard_type
|
||||
logging.debug(f"sr: {sample_rate}")
|
||||
logging.debug(f"n_fft: {self.n_fft}")
|
||||
logging.debug(f"win_length: {self.win_length}")
|
||||
logging.debug(f"hop_length: {self.hop_length}")
|
||||
logging.debug(f"n_mels: {nfilt}")
|
||||
logging.debug(f"fmin: {lowfreq}")
|
||||
logging.debug(f"fmax: {highfreq}")
|
||||
logging.debug(f"using grads: {use_grads}")
|
||||
|
||||
def log_zero_guard_value_fn(self, x):
|
||||
if isinstance(self.log_zero_guard_value, str):
|
||||
if self.log_zero_guard_value == "tiny":
|
||||
return torch.finfo(x.dtype).tiny
|
||||
elif self.log_zero_guard_value == "eps":
|
||||
return torch.finfo(x.dtype).eps
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self} received {self.log_zero_guard_value} for the "
|
||||
f"log_zero_guard_type parameter. It must be either a "
|
||||
f"number, 'tiny', or 'eps'"
|
||||
)
|
||||
else:
|
||||
return self.log_zero_guard_value
|
||||
|
||||
def get_seq_len(self, seq_len):
|
||||
if isinstance(self.stft, STFT):
|
||||
pad_amount = self.stft.pad_amount * 2
|
||||
else:
|
||||
# Assuming that center is True is stft_pad_amount = 0
|
||||
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
|
||||
seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
|
||||
return seq_len.to(dtype=torch.long)
|
||||
|
||||
@property
|
||||
def filter_banks(self):
|
||||
return self.fb
|
||||
|
||||
def forward(self, x, seq_len):
|
||||
seq_len = self.get_seq_len(seq_len.float())
|
||||
|
||||
if self.stft_pad_amount is not None:
|
||||
x = torch.nn.functional.pad(
|
||||
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
|
||||
).squeeze(1)
|
||||
|
||||
# dither (only in training mode for eval determinism)
|
||||
if self.training and self.dither > 0:
|
||||
x += self.dither * torch.randn_like(x)
|
||||
|
||||
# do preemphasis
|
||||
if self.preemph is not None:
|
||||
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
|
||||
|
||||
# disable autocast to get full range of stft values
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x = self.stft(x)
|
||||
|
||||
# torch returns real, imag; so convert to magnitude
|
||||
if not self.stft_conv:
|
||||
# guard is needed for sqrt if grads are passed through
|
||||
guard = 0 if not self.use_grads else CONSTANT
|
||||
if x.dtype in [torch.cfloat, torch.cdouble]:
|
||||
x = torch.view_as_real(x)
|
||||
x = torch.sqrt(x.pow(2).sum(-1) + guard)
|
||||
|
||||
# get power spectrum
|
||||
if self.mag_power != 1.0:
|
||||
x = x.pow(self.mag_power)
|
||||
|
||||
# dot with filterbank energies
|
||||
x = torch.matmul(self.fb.to(x.dtype), x)
|
||||
|
||||
# log features if required
|
||||
if self.log:
|
||||
if self.log_zero_guard_type == "add":
|
||||
x = torch.log(x + self.log_zero_guard_value_fn(x))
|
||||
elif self.log_zero_guard_type == "clamp":
|
||||
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
|
||||
else:
|
||||
raise ValueError("log_zero_guard_type was not understood")
|
||||
|
||||
# frame splicing if required
|
||||
if self.frame_splicing > 1:
|
||||
x = splice_frames(x, self.frame_splicing)
|
||||
|
||||
# normalize if required
|
||||
if self.normalize:
|
||||
x = normalize_batch(x, seq_len, normalize_type=self.normalize)
|
||||
|
||||
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
|
||||
max_len = x.size(-1)
|
||||
mask = torch.arange(max_len).to(x.device)
|
||||
mask = mask.expand(x.size(0), max_len) >= seq_len.unsqueeze(1)
|
||||
x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
|
||||
del mask
|
||||
pad_to = self.pad_to
|
||||
if pad_to == "max":
|
||||
x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
|
||||
elif pad_to > 0:
|
||||
pad_amt = x.size(-1) % pad_to
|
||||
if pad_amt != 0:
|
||||
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
|
||||
|
||||
return x, seq_len
|
||||
"""
|
||||
ALIAS FILE for backward compatibility
|
||||
"""
|
||||
from nemo.collections.asr.parts.preprocessing.features import *
|
||||
|
|
15
nemo/collections/asr/parts/mixins/__init__.py
Normal file
15
nemo/collections/asr/parts/mixins/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.asr.parts.mixins.mixins import ASRBPEMixin, ASRModuleMixin, DiarizationMixin
|
|
@ -18,7 +18,7 @@ from typing import List
|
|||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
|
||||
from nemo.collections.asr.parts import asr_module_utils
|
||||
from nemo.collections.asr.parts.utils import asr_module_utils
|
||||
from nemo.collections.common import tokenizers
|
||||
from nemo.utils import logging
|
||||
|
41
nemo/collections/asr/parts/preprocessing/__init__.py
Normal file
41
nemo/collections/asr/parts/preprocessing/__init__.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader
|
||||
from nemo.collections.asr.parts.preprocessing.features import (
|
||||
STFT,
|
||||
FeaturizerFactory,
|
||||
FilterbankFeatures,
|
||||
STFTExactPad,
|
||||
STFTPatch,
|
||||
WaveformFeaturizer,
|
||||
)
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import (
|
||||
AudioAugmentor,
|
||||
AugmentationDataset,
|
||||
GainPerturbation,
|
||||
ImpulsePerturbation,
|
||||
NoisePerturbation,
|
||||
Perturbation,
|
||||
RirAndNoisePerturbation,
|
||||
ShiftPerturbation,
|
||||
SpeedPerturbation,
|
||||
TimeStretchPerturbation,
|
||||
TranscodePerturbation,
|
||||
WhiteNoisePerturbation,
|
||||
perturbation_types,
|
||||
process_augmentations,
|
||||
register_perturbation,
|
||||
)
|
||||
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
432
nemo/collections/asr/parts/preprocessing/features.py
Normal file
432
nemo/collections/asr/parts/preprocessing/features.py
Normal file
|
@ -0,0 +1,432 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Copyright (c) 2018 Ryan Leary
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
# This file contains code artifacts adapted from https://github.com/ryanleary/patter
|
||||
import math
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from librosa.util import tiny
|
||||
from torch.autograd import Variable
|
||||
from torch_stft import STFT
|
||||
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor
|
||||
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
||||
from nemo.collections.common.parts.patch_utils import stft_patch
|
||||
from nemo.utils import logging
|
||||
|
||||
CONSTANT = 1e-5
|
||||
|
||||
|
||||
def normalize_batch(x, seq_len, normalize_type):
|
||||
if normalize_type == "per_feature":
|
||||
x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
||||
x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
||||
for i in range(x.shape[0]):
|
||||
if x[i, :, : seq_len[i]].shape[1] == 1:
|
||||
raise ValueError(
|
||||
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
|
||||
"in torch.std() returning nan"
|
||||
)
|
||||
x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
|
||||
x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
|
||||
# make sure x_std is not zero
|
||||
x_std += CONSTANT
|
||||
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
|
||||
elif normalize_type == "all_features":
|
||||
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
||||
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
||||
for i in range(x.shape[0]):
|
||||
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
|
||||
x_std[i] = x[i, :, : seq_len[i].item()].std()
|
||||
# make sure x_std is not zero
|
||||
x_std += CONSTANT
|
||||
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
|
||||
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
|
||||
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
|
||||
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
|
||||
return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def splice_frames(x, frame_splicing):
|
||||
""" Stacks frames together across feature dim
|
||||
|
||||
input is batch_size, feature_dim, num_frames
|
||||
output is batch_size, feature_dim*frame_splicing, num_frames
|
||||
|
||||
"""
|
||||
seq = [x]
|
||||
for n in range(1, frame_splicing):
|
||||
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
|
||||
class WaveformFeaturizer(object):
|
||||
def __init__(self, sample_rate=16000, int_values=False, augmentor=None):
|
||||
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False, orig_sr=None):
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
orig_sr=orig_sr,
|
||||
)
|
||||
return self.process_segment(audio)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment.samples, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
|
||||
|
||||
|
||||
class FeaturizerFactory(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_cfg, perturbation_configs=None):
|
||||
return WaveformFeaturizer.from_config(input_cfg, perturbation_configs=perturbation_configs)
|
||||
|
||||
|
||||
# Create helper class to patch forward func for use with AMP
|
||||
class STFTPatch(STFT):
|
||||
def forward(self, input_data):
|
||||
return super().transform(input_data)[0]
|
||||
|
||||
|
||||
# Create helper class for STFT that yields num_frames = num_samples // hop_length
|
||||
class STFTExactPad(STFTPatch):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
|
||||
def __init__(self, *params, **kw_params):
|
||||
super().__init__(*params, **kw_params)
|
||||
self.pad_amount = (self.filter_length - self.hop_length) // 2
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
|
||||
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
recombine_magnitude_phase,
|
||||
Variable(self.inverse_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
if self.window is not None:
|
||||
window_sum = librosa.filters.window_sumsquare(
|
||||
self.window,
|
||||
magnitude.size(-1),
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
n_fft=self.filter_length,
|
||||
dtype=np.float32,
|
||||
)
|
||||
# remove modulation effects
|
||||
approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0])
|
||||
window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False).to(
|
||||
magnitude.device
|
||||
)
|
||||
inverse_transform[..., approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= self.filter_length / self.hop_length
|
||||
|
||||
inverse_transform = inverse_transform[..., self.pad_amount :]
|
||||
inverse_transform = inverse_transform[..., : -self.pad_amount :]
|
||||
inverse_transform = inverse_transform.squeeze(1)
|
||||
|
||||
return inverse_transform
|
||||
|
||||
|
||||
class FilterbankFeatures(nn.Module):
|
||||
"""Featurizer that converts wavs to Mel Spectrograms.
|
||||
See AudioToMelSpectrogramPreprocessor for args.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=16000,
|
||||
n_window_size=320,
|
||||
n_window_stride=160,
|
||||
window="hann",
|
||||
normalize="per_feature",
|
||||
n_fft=None,
|
||||
preemph=0.97,
|
||||
nfilt=64,
|
||||
lowfreq=0,
|
||||
highfreq=None,
|
||||
log=True,
|
||||
log_zero_guard_type="add",
|
||||
log_zero_guard_value=2 ** -24,
|
||||
dither=CONSTANT,
|
||||
pad_to=16,
|
||||
max_duration=16.7,
|
||||
frame_splicing=1,
|
||||
exact_pad=False,
|
||||
stft_exact_pad=False, # TODO: Remove this in 1.1.0
|
||||
stft_conv=False, # TODO: Remove this in 1.1.0
|
||||
pad_value=0,
|
||||
mag_power=2.0,
|
||||
use_grads=False,
|
||||
):
|
||||
super().__init__()
|
||||
if stft_conv or stft_exact_pad:
|
||||
logging.warning(
|
||||
"Using torch_stft is deprecated and will be removed in 1.1.0. Please set stft_conv and stft_exact_pad "
|
||||
"to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
|
||||
"as needed."
|
||||
)
|
||||
if (exact_pad or stft_exact_pad) and n_window_stride % 2 == 1:
|
||||
raise NotImplementedError(
|
||||
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
|
||||
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
|
||||
)
|
||||
self.log_zero_guard_value = log_zero_guard_value
|
||||
if (
|
||||
n_window_size is None
|
||||
or n_window_stride is None
|
||||
or not isinstance(n_window_size, int)
|
||||
or not isinstance(n_window_stride, int)
|
||||
or n_window_size <= 0
|
||||
or n_window_stride <= 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"{self} got an invalid value for either n_window_size or "
|
||||
f"n_window_stride. Both must be positive ints."
|
||||
)
|
||||
logging.info(f"PADDING: {pad_to}")
|
||||
|
||||
self.win_length = n_window_size
|
||||
self.hop_length = n_window_stride
|
||||
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
|
||||
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
|
||||
self.stft_exact_pad = stft_exact_pad
|
||||
self.stft_conv = stft_conv
|
||||
|
||||
if stft_conv:
|
||||
logging.info("STFT using conv")
|
||||
if stft_exact_pad:
|
||||
logging.info("STFT using exact pad")
|
||||
self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window)
|
||||
else:
|
||||
self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window)
|
||||
else:
|
||||
logging.info("STFT using torch")
|
||||
if exact_pad:
|
||||
logging.info("STFT using exact pad")
|
||||
torch_windows = {
|
||||
'hann': torch.hann_window,
|
||||
'hamming': torch.hamming_window,
|
||||
'blackman': torch.blackman_window,
|
||||
'bartlett': torch.bartlett_window,
|
||||
'none': None,
|
||||
}
|
||||
window_fn = torch_windows.get(window, None)
|
||||
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
|
||||
self.register_buffer("window", window_tensor)
|
||||
self.stft = lambda x: stft_patch(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
center=False if exact_pad else True,
|
||||
window=self.window.to(dtype=torch.float),
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
self.normalize = normalize
|
||||
self.log = log
|
||||
self.dither = dither
|
||||
self.frame_splicing = frame_splicing
|
||||
self.nfilt = nfilt
|
||||
self.preemph = preemph
|
||||
self.pad_to = pad_to
|
||||
highfreq = highfreq or sample_rate / 2
|
||||
|
||||
filterbanks = torch.tensor(
|
||||
librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float
|
||||
).unsqueeze(0)
|
||||
self.register_buffer("fb", filterbanks)
|
||||
|
||||
# Calculate maximum sequence length
|
||||
max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
|
||||
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
|
||||
self.max_length = max_length + max_pad
|
||||
self.pad_value = pad_value
|
||||
self.mag_power = mag_power
|
||||
|
||||
# We want to avoid taking the log of zero
|
||||
# There are two options: either adding or clamping to a small value
|
||||
if log_zero_guard_type not in ["add", "clamp"]:
|
||||
raise ValueError(
|
||||
f"{self} received {log_zero_guard_type} for the "
|
||||
f"log_zero_guard_type parameter. It must be either 'add' or "
|
||||
f"'clamp'."
|
||||
)
|
||||
|
||||
self.use_grads = use_grads
|
||||
if not use_grads:
|
||||
self.forward = torch.no_grad()(self.forward)
|
||||
|
||||
# log_zero_guard_value is the the small we want to use, we support
|
||||
# an actual number, or "tiny", or "eps"
|
||||
self.log_zero_guard_type = log_zero_guard_type
|
||||
logging.debug(f"sr: {sample_rate}")
|
||||
logging.debug(f"n_fft: {self.n_fft}")
|
||||
logging.debug(f"win_length: {self.win_length}")
|
||||
logging.debug(f"hop_length: {self.hop_length}")
|
||||
logging.debug(f"n_mels: {nfilt}")
|
||||
logging.debug(f"fmin: {lowfreq}")
|
||||
logging.debug(f"fmax: {highfreq}")
|
||||
logging.debug(f"using grads: {use_grads}")
|
||||
|
||||
def log_zero_guard_value_fn(self, x):
|
||||
if isinstance(self.log_zero_guard_value, str):
|
||||
if self.log_zero_guard_value == "tiny":
|
||||
return torch.finfo(x.dtype).tiny
|
||||
elif self.log_zero_guard_value == "eps":
|
||||
return torch.finfo(x.dtype).eps
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self} received {self.log_zero_guard_value} for the "
|
||||
f"log_zero_guard_type parameter. It must be either a "
|
||||
f"number, 'tiny', or 'eps'"
|
||||
)
|
||||
else:
|
||||
return self.log_zero_guard_value
|
||||
|
||||
def get_seq_len(self, seq_len):
|
||||
if isinstance(self.stft, STFT):
|
||||
pad_amount = self.stft.pad_amount * 2
|
||||
else:
|
||||
# Assuming that center is True is stft_pad_amount = 0
|
||||
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
|
||||
seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
|
||||
return seq_len.to(dtype=torch.long)
|
||||
|
||||
@property
|
||||
def filter_banks(self):
|
||||
return self.fb
|
||||
|
||||
def forward(self, x, seq_len):
|
||||
seq_len = self.get_seq_len(seq_len.float())
|
||||
|
||||
if self.stft_pad_amount is not None:
|
||||
x = torch.nn.functional.pad(
|
||||
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
|
||||
).squeeze(1)
|
||||
|
||||
# dither (only in training mode for eval determinism)
|
||||
if self.training and self.dither > 0:
|
||||
x += self.dither * torch.randn_like(x)
|
||||
|
||||
# do preemphasis
|
||||
if self.preemph is not None:
|
||||
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
|
||||
|
||||
# disable autocast to get full range of stft values
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x = self.stft(x)
|
||||
|
||||
# torch returns real, imag; so convert to magnitude
|
||||
if not self.stft_conv:
|
||||
# guard is needed for sqrt if grads are passed through
|
||||
guard = 0 if not self.use_grads else CONSTANT
|
||||
if x.dtype in [torch.cfloat, torch.cdouble]:
|
||||
x = torch.view_as_real(x)
|
||||
x = torch.sqrt(x.pow(2).sum(-1) + guard)
|
||||
|
||||
# get power spectrum
|
||||
if self.mag_power != 1.0:
|
||||
x = x.pow(self.mag_power)
|
||||
|
||||
# dot with filterbank energies
|
||||
x = torch.matmul(self.fb.to(x.dtype), x)
|
||||
|
||||
# log features if required
|
||||
if self.log:
|
||||
if self.log_zero_guard_type == "add":
|
||||
x = torch.log(x + self.log_zero_guard_value_fn(x))
|
||||
elif self.log_zero_guard_type == "clamp":
|
||||
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
|
||||
else:
|
||||
raise ValueError("log_zero_guard_type was not understood")
|
||||
|
||||
# frame splicing if required
|
||||
if self.frame_splicing > 1:
|
||||
x = splice_frames(x, self.frame_splicing)
|
||||
|
||||
# normalize if required
|
||||
if self.normalize:
|
||||
x = normalize_batch(x, seq_len, normalize_type=self.normalize)
|
||||
|
||||
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
|
||||
max_len = x.size(-1)
|
||||
mask = torch.arange(max_len).to(x.device)
|
||||
mask = mask.expand(x.size(0), max_len) >= seq_len.unsqueeze(1)
|
||||
x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
|
||||
del mask
|
||||
pad_to = self.pad_to
|
||||
if pad_to == "max":
|
||||
x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
|
||||
elif pad_to > 0:
|
||||
pad_amt = x.size(-1) % pad_to
|
||||
if pad_amt != 0:
|
||||
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
|
||||
|
||||
return x, seq_len
|
|
@ -48,12 +48,12 @@ from omegaconf import DictConfig, OmegaConf
|
|||
from scipy import signal
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from nemo.collections.asr.parts import collections, parsers
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
||||
from nemo.collections.common.parts.preprocessing import collections, parsers
|
||||
from nemo.utils import logging
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.parts import numba_utils
|
||||
from nemo.collections.asr.parts.utils import numba_utils
|
||||
|
||||
HAVE_NUMBA = True
|
||||
except (ImportError, ModuleNotFoundError):
|
|
@ -42,6 +42,7 @@ import soundfile as sf
|
|||
from kaldiio.matio import read_kaldi
|
||||
from kaldiio.utils import open_like_kaldi
|
||||
from pydub import AudioSegment as Audio
|
||||
from pydub.exceptions import CouldntDecodeError
|
||||
|
||||
from nemo.utils import logging
|
||||
|
||||
|
@ -145,7 +146,8 @@ class AudioSegment(object):
|
|||
samples = samples.transpose()
|
||||
except RuntimeError as e:
|
||||
logging.error(
|
||||
f"Loading audio via SoundFile raised RuntimeError: `{e}`. NeMo will fallback to loading via pydub."
|
||||
f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`. "
|
||||
f"NeMo will fallback to loading via pydub."
|
||||
)
|
||||
elif isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
|
||||
f = open_like_kaldi(audio_file, "rb")
|
||||
|
@ -159,16 +161,19 @@ class AudioSegment(object):
|
|||
samples = np.array(samples, dtype=np.float) / abs_max_value
|
||||
|
||||
if samples is None:
|
||||
samples = Audio.from_file(audio_file)
|
||||
sample_rate = samples.frame_rate
|
||||
if offset > 0:
|
||||
# pydub does things in milliseconds
|
||||
seconds = offset * 1000
|
||||
samples = samples[int(seconds * sample_rate) :]
|
||||
if duration > 0:
|
||||
seconds = duration * 1000
|
||||
samples = samples[: int(seconds)]
|
||||
samples = np.array(samples.get_array_of_samples())
|
||||
try:
|
||||
samples = Audio.from_file(audio_file)
|
||||
sample_rate = samples.frame_rate
|
||||
if offset > 0:
|
||||
# pydub does things in milliseconds
|
||||
seconds = offset * 1000
|
||||
samples = samples[int(seconds * sample_rate) :]
|
||||
if duration > 0:
|
||||
seconds = duration * 1000
|
||||
samples = samples[: int(seconds)]
|
||||
samples = np.array(samples.get_array_of_samples())
|
||||
except CouldntDecodeError as e:
|
||||
logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{e}`.")
|
||||
|
||||
return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
|
||||
|
||||
|
@ -179,15 +184,19 @@ class AudioSegment(object):
|
|||
|
||||
Note that audio_file can be either the file path, or a file-like object.
|
||||
"""
|
||||
with sf.SoundFile(audio_file, 'r') as f:
|
||||
sample_rate = f.samplerate
|
||||
if n_segments > 0 and len(f) > n_segments:
|
||||
max_audio_start = len(f) - n_segments
|
||||
audio_start = random.randint(0, max_audio_start)
|
||||
f.seek(audio_start)
|
||||
samples = f.read(n_segments, dtype='float32')
|
||||
else:
|
||||
samples = f.read(dtype='float32')
|
||||
try:
|
||||
with sf.SoundFile(audio_file, 'r') as f:
|
||||
sample_rate = f.samplerate
|
||||
if n_segments > 0 and len(f) > n_segments:
|
||||
max_audio_start = len(f) - n_segments
|
||||
audio_start = random.randint(0, max_audio_start)
|
||||
f.seek(audio_start)
|
||||
samples = f.read(n_segments, dtype='float32')
|
||||
else:
|
||||
samples = f.read(dtype='float32')
|
||||
samples = samples.transpose()
|
||||
except RuntimeError as e:
|
||||
logging.error(f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`.")
|
||||
|
||||
samples = samples.transpose()
|
||||
return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
|
13
nemo/collections/asr/parts/submodules/__init__.py
Normal file
13
nemo/collections/asr/parts/submodules/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -16,8 +16,11 @@ import torch
|
|||
from torch import nn as nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from nemo.collections.asr.parts.activations import Swish
|
||||
from nemo.collections.asr.parts.multi_head_attention import MultiHeadAttention, RelPositionMultiHeadAttention
|
||||
from nemo.collections.asr.parts.submodules.multi_head_attention import (
|
||||
MultiHeadAttention,
|
||||
RelPositionMultiHeadAttention,
|
||||
)
|
||||
from nemo.collections.asr.parts.utils.activations import Swish
|
||||
|
||||
__all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerLayer']
|
||||
|
|
@ -21,7 +21,7 @@ from torch import Tensor
|
|||
from torch.nn.init import _calculate_correct_fan
|
||||
from torch.nn.modules.utils import _single
|
||||
|
||||
from nemo.collections.asr.parts.activations import Swish
|
||||
from nemo.collections.asr.parts.utils.activations import Swish
|
||||
from nemo.utils import logging
|
||||
|
||||
try:
|
||||
|
@ -34,7 +34,7 @@ try:
|
|||
except ImportError:
|
||||
PYTORCH_QUANTIZATION_AVAILABLE = False
|
||||
|
||||
jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish}
|
||||
jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish, "silu": nn.SiLU}
|
||||
|
||||
|
||||
def tds_uniform_(tensor, mode='fan_in'):
|
|
@ -35,8 +35,7 @@ import torch
|
|||
from tqdm import tqdm
|
||||
|
||||
from nemo.collections.asr.modules import rnnt_abstract
|
||||
from nemo.collections.asr.parts import rnnt_utils
|
||||
from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
|
||||
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
|
||||
from nemo.core.classes import Typing, typecheck
|
||||
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
|
||||
|
|
@ -32,7 +32,7 @@ from typing import List, Optional, Union
|
|||
import torch
|
||||
|
||||
from nemo.collections.asr.modules import rnnt_abstract
|
||||
from nemo.collections.asr.parts import rnnt_utils
|
||||
from nemo.collections.asr.parts.utils import rnnt_utils
|
||||
from nemo.collections.common.parts.rnn import label_collate
|
||||
from nemo.core.classes import Typing, typecheck
|
||||
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
|
|
@ -17,8 +17,11 @@ import random
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nemo.core.classes import Typing, typecheck
|
||||
from nemo.core.neural_types import NeuralType, SpectrogramType
|
||||
|
||||
class SpecAugment(nn.Module):
|
||||
|
||||
class SpecAugment(nn.Module, Typing):
|
||||
"""
|
||||
Zeroes out(cuts) random continuous horisontal or
|
||||
vertical segments of the spectrogram as described in
|
||||
|
@ -36,6 +39,18 @@ class SpecAugment(nn.Module):
|
|||
are cut adaptively.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
"""Returns definitions of module input types
|
||||
"""
|
||||
return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
"""Returns definitions of module output types
|
||||
"""
|
||||
return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
|
||||
|
||||
def __init__(
|
||||
self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, rng=None, mask_value=0.0,
|
||||
):
|
||||
|
@ -59,9 +74,10 @@ class SpecAugment(nn.Module):
|
|||
|
||||
self.adaptive_temporal_width = True
|
||||
|
||||
@typecheck()
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
sh = x.shape
|
||||
def forward(self, input_spec):
|
||||
sh = input_spec.shape
|
||||
|
||||
if self.adaptive_temporal_width:
|
||||
time_width = max(1, int(sh[2] * self.time_width))
|
||||
|
@ -74,19 +90,19 @@ class SpecAugment(nn.Module):
|
|||
|
||||
w = self._rng.randint(0, self.freq_width)
|
||||
|
||||
x[idx, x_left : x_left + w, :] = self.mask_value
|
||||
input_spec[idx, x_left : x_left + w, :] = self.mask_value
|
||||
|
||||
for i in range(self.time_masks):
|
||||
y_left = self._rng.randint(0, sh[2] - time_width)
|
||||
|
||||
w = self._rng.randint(0, time_width)
|
||||
|
||||
x[idx, :, y_left : y_left + w] = self.mask_value
|
||||
input_spec[idx, :, y_left : y_left + w] = self.mask_value
|
||||
|
||||
return x
|
||||
return input_spec
|
||||
|
||||
|
||||
class SpecCutout(nn.Module):
|
||||
class SpecCutout(nn.Module, Typing):
|
||||
"""
|
||||
Zeroes out(cuts) random rectangles in the spectrogram
|
||||
as described in (https://arxiv.org/abs/1708.04552).
|
||||
|
@ -97,6 +113,18 @@ class SpecCutout(nn.Module):
|
|||
rect_time - maximum size of cut rectangles along the time dimension
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_types(self):
|
||||
"""Returns definitions of module input types
|
||||
"""
|
||||
return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
"""Returns definitions of module output types
|
||||
"""
|
||||
return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
|
||||
|
||||
def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
|
||||
super(SpecCutout, self).__init__()
|
||||
|
||||
|
@ -106,9 +134,10 @@ class SpecCutout(nn.Module):
|
|||
self.rect_time = rect_time
|
||||
self.rect_freq = rect_freq
|
||||
|
||||
@typecheck()
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
sh = x.shape
|
||||
def forward(self, input_spec):
|
||||
sh = input_spec.shape
|
||||
|
||||
for idx in range(sh[0]):
|
||||
for i in range(self.rect_masks):
|
||||
|
@ -118,6 +147,6 @@ class SpecCutout(nn.Module):
|
|||
w_x = self._rng.randint(0, self.rect_freq)
|
||||
w_y = self._rng.randint(0, self.rect_time)
|
||||
|
||||
x[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
|
||||
input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
|
||||
|
||||
return x
|
||||
return input_spec
|
13
nemo/collections/asr/parts/utils/__init__.py
Normal file
13
nemo/collections/asr/parts/utils/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -12,16 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['Swish']
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
class Swish(nn.SiLU):
|
||||
"""
|
||||
Swish activation function introduced in 'https://arxiv.org/abs/1710.05941'
|
||||
Mathematically identical to SiLU. See note in nn.SiLU for references.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
|
@ -17,7 +17,7 @@ from typing import Optional
|
|||
from omegaconf import DictConfig, open_dict
|
||||
|
||||
from nemo.collections.asr.modules import conv_asr
|
||||
from nemo.collections.asr.parts import jasper
|
||||
from nemo.collections.asr.parts.submodules import jasper
|
||||
from nemo.utils import logging
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ from pyannote.core import Annotation, Segment
|
|||
from pyannote.metrics.diarization import DiarizationErrorRate
|
||||
from tqdm import tqdm
|
||||
|
||||
from nemo.collections.asr.parts.nmse_clustering import COSclustering
|
||||
from nemo.collections.asr.parts.utils.nmse_clustering import COSclustering
|
||||
from nemo.utils import logging
|
||||
|
||||
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import nemo.collections.common.callbacks
|
||||
from nemo.collections.common import losses, parts, tokenizers
|
||||
from nemo.collections.common import data, losses, parts, tokenizers
|
||||
from nemo.package_info import __version__
|
||||
|
||||
# Set collection version equal to NeMo version.
|
||||
|
|
15
nemo/collections/common/data/__init__.py
Normal file
15
nemo/collections/common/data/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.common.data.dataset import ConcatDataset
|
178
nemo/collections/common/data/dataset.py
Normal file
178
nemo/collections/common/data/dataset.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
import torch.utils.data as pt_data
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
__all__ = ['ConcatDataset']
|
||||
|
||||
|
||||
class ConcatDataset(IterableDataset):
|
||||
"""
|
||||
A dataset that accepts as argument multiple datasets and then samples from them based on the specified
|
||||
sampling technique.
|
||||
Args:
|
||||
datasets (list): A list of datasets to sample from.
|
||||
shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets.
|
||||
Defaults to True.
|
||||
sampling_technique (str): Sampling technique to choose which dataset to draw a sample from.
|
||||
Defaults to 'temperature'. Currently supports 'temperature', 'random' and 'round-robin'.
|
||||
sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'.
|
||||
Defaults to 5.
|
||||
sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'.
|
||||
global_rank (int): Worker rank, used for partitioning map style datasets. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning map style datasets. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasets: List[Any],
|
||||
shuffle: bool = True,
|
||||
sampling_technique: str = 'temperature',
|
||||
sampling_temperature: int = 5,
|
||||
sampling_probabilities: List[float] = None,
|
||||
global_rank: int = 0,
|
||||
world_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
supported_sampling_techniques = ['temperature', 'random', 'round-robin']
|
||||
self.datasets = datasets
|
||||
self.iterables = [None] * len(datasets)
|
||||
self.shuffle = shuffle
|
||||
self.global_rank = global_rank
|
||||
self.world_size = world_size
|
||||
self.sampling_kwargs = {}
|
||||
if sampling_technique == 'temperature':
|
||||
self.index_generator = ConcatDataset.temperature_generator
|
||||
self.sampling_kwargs['temperature'] = sampling_temperature
|
||||
elif sampling_technique == 'random':
|
||||
self.index_generator = ConcatDataset.random_generator
|
||||
self.sampling_kwargs['p'] = sampling_probabilities
|
||||
elif sampling_technique == 'round-robin':
|
||||
self.index_generator = ConcatDataset.round_robin_generator
|
||||
else:
|
||||
raise ValueError(f"Currently we only support sampling techniques in {supported_sampling_techniques}.")
|
||||
self.length = 0
|
||||
|
||||
if isinstance(datasets[0], IterableDataset):
|
||||
self.kind = 'iterable'
|
||||
else:
|
||||
self.kind = 'map'
|
||||
|
||||
for idx, dataset in enumerate(datasets):
|
||||
isiterable = isinstance(dataset, IterableDataset)
|
||||
if (isiterable and not self.kind == 'iterable') or (not isiterable and self.kind == 'iterable'):
|
||||
raise ValueError("All datasets in ConcatDataset must be of the same kind (Iterable or Map).")
|
||||
|
||||
if self.kind == 'map':
|
||||
self.length += len(dataset) // world_size
|
||||
else:
|
||||
self.length += len(dataset)
|
||||
|
||||
def get_iterable(self, dataset):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
return dataset.__iter__()
|
||||
else:
|
||||
indices = np.arange(len(dataset))
|
||||
if self.shuffle:
|
||||
np.random.shuffle(indices)
|
||||
return iter(indices)
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = pt_data.get_worker_info()
|
||||
if worker_info is None:
|
||||
max_elements = self.length
|
||||
wid = 0
|
||||
wnum = 1
|
||||
else:
|
||||
wid = worker_info.id
|
||||
wnum = worker_info.num_workers
|
||||
max_elements = len(range(wid, self.length, wnum))
|
||||
|
||||
if self.kind == 'map':
|
||||
for idx in range(len(self.datasets)):
|
||||
start_idx = (len(self.datasets[idx]) // self.world_size) * self.global_rank
|
||||
end_idx = start_idx + (len(self.datasets[idx]) // self.world_size)
|
||||
if self.global_rank == self.world_size - 1:
|
||||
end_idx = len(self.datasets[idx])
|
||||
indices = range(start_idx + wid, end_idx, wnum)
|
||||
self.datasets[idx] = pt_data.Subset(self.datasets[idx], indices)
|
||||
|
||||
for idx, dataset in enumerate(self.datasets):
|
||||
iterable = self.get_iterable(dataset)
|
||||
self.iterables[idx] = iterable
|
||||
|
||||
n = 0
|
||||
ind_gen = self.index_generator(self.datasets, **self.sampling_kwargs)
|
||||
while n < max_elements:
|
||||
n += 1
|
||||
try:
|
||||
ind = next(ind_gen)
|
||||
except StopIteration:
|
||||
return
|
||||
try:
|
||||
val = next(self.iterables[ind])
|
||||
if self.kind == 'map':
|
||||
val = self.datasets[ind][val]
|
||||
yield val
|
||||
except StopIteration:
|
||||
self.iterables[ind] = self.get_iterable(self.datasets[ind])
|
||||
n -= 1
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
@staticmethod
|
||||
def temperature_generator(datasets, **kwargs):
|
||||
temp = kwargs.get('temperature')
|
||||
if not temp:
|
||||
raise ValueError("Temperature generator expects a 'temperature' keyowrd argument.")
|
||||
|
||||
lengths = []
|
||||
num = len(datasets)
|
||||
for dataset in datasets:
|
||||
lengths.append(len(dataset))
|
||||
|
||||
p = np.array(lengths) / np.sum(lengths)
|
||||
p = np.power(p, 1 / temp)
|
||||
p = p / np.sum(p)
|
||||
|
||||
while True:
|
||||
ind = np.random.choice(np.arange(num), p=p)
|
||||
yield ind
|
||||
|
||||
@staticmethod
|
||||
def round_robin_generator(datasets, **kwargs):
|
||||
num = len(datasets)
|
||||
while True:
|
||||
for i in range(num):
|
||||
yield i
|
||||
|
||||
@staticmethod
|
||||
def random_generator(datasets, **kwargs):
|
||||
p = kwargs.get('p')
|
||||
if not p:
|
||||
raise ValueError("Random generator expects a 'p' keyowrd argument for sampling probabilities.")
|
||||
|
||||
num = len(datasets)
|
||||
if len(p) != num:
|
||||
raise ValueError("Length of probabilities list must be equal to the number of datasets.")
|
||||
|
||||
while True:
|
||||
ind = np.random.choice(np.arange(num), p=p)
|
||||
yield ind
|
13
nemo/collections/common/parts/preprocessing/__init__.py
Normal file
13
nemo/collections/common/parts/preprocessing/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import pandas as pd
|
||||
|
||||
from nemo.collections.asr.parts import manifest, parsers
|
||||
from nemo.collections.common.parts.preprocessing import manifest, parsers
|
||||
from nemo.utils import logging
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ from typing import List, Optional
|
|||
|
||||
import frozendict
|
||||
|
||||
from nemo.collections.asr.parts import cleaners
|
||||
from nemo.collections.common.parts.preprocessing import cleaners
|
||||
|
||||
|
||||
class CharParser:
|
|
@ -27,7 +27,6 @@ from nemo.collections.nlp.data.language_modeling.lm_bert_dataset import (
|
|||
)
|
||||
from nemo.collections.nlp.data.language_modeling.sentence_dataset import SentenceDataset, TarredSentenceDataset
|
||||
from nemo.collections.nlp.data.machine_translation.machine_translation_dataset import (
|
||||
ConcatTranslationDataset,
|
||||
TarredTranslationDataset,
|
||||
TranslationDataset,
|
||||
)
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
from nemo.collections.nlp.data.machine_translation.machine_translation_dataset import (
|
||||
ConcatTranslationDataset,
|
||||
TarredTranslationDataset,
|
||||
TranslationDataset,
|
||||
)
|
||||
|
|
|
@ -23,7 +23,6 @@ from typing import Any, List, Optional
|
|||
|
||||
import braceexpand
|
||||
import numpy as np
|
||||
import torch.utils.data as pt_data
|
||||
import webdataset as wd
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
@ -31,7 +30,7 @@ from nemo.collections.nlp.data.data_utils.data_preprocessing import dataset_to_i
|
|||
from nemo.core import Dataset
|
||||
from nemo.utils import logging
|
||||
|
||||
__all__ = ['TranslationDataset', 'TarredTranslationDataset', 'ConcatTranslationDataset']
|
||||
__all__ = ['TranslationDataset', 'TarredTranslationDataset']
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -444,162 +443,3 @@ class TarredTranslationDataset(IterableDataset):
|
|||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
class ConcatTranslationDataset(IterableDataset):
|
||||
"""
|
||||
A dataset that accepts as argument multiple datasets and then samples from them based on the specified
|
||||
sampling technique.
|
||||
Args:
|
||||
datasets (list): A list of datasets to sample from.
|
||||
shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets.
|
||||
Defaults to True.
|
||||
sampling_technique (str): Sampling technique to choose which dataset to draw a sample from.
|
||||
Defaults to 'temperature'. Currently supports 'temperature', 'random' and 'round-robin'.
|
||||
sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'.
|
||||
Defaults to 5.
|
||||
sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'.
|
||||
global_rank (int): Worker rank, used for partitioning map style datasets. Defaults to 0.
|
||||
world_size (int): Total number of processes, used for partitioning map style datasets. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasets: List[Any],
|
||||
shuffle: bool = True,
|
||||
sampling_technique: str = 'temperature',
|
||||
sampling_temperature: int = 5,
|
||||
sampling_probabilities: List[float] = None,
|
||||
global_rank: int = 0,
|
||||
world_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
supported_sampling_techniques = ['temperature', 'random', 'round-robin']
|
||||
self.datasets = datasets
|
||||
self.iterables = [None] * len(datasets)
|
||||
self.shuffle = shuffle
|
||||
self.global_rank = global_rank
|
||||
self.world_size = world_size
|
||||
self.sampling_kwargs = {}
|
||||
if sampling_technique == 'temperature':
|
||||
self.index_generator = ConcatTranslationDataset.temperature_generator
|
||||
self.sampling_kwargs['temperature'] = sampling_temperature
|
||||
elif sampling_technique == 'random':
|
||||
self.index_generator = ConcatTranslationDataset.random_generator
|
||||
self.sampling_kwargs['p'] = sampling_probabilities
|
||||
elif sampling_technique == 'round-robin':
|
||||
self.index_generator = ConcatTranslationDataset.round_robin_generator
|
||||
else:
|
||||
raise ValueError(f"Currently we only support sampling techniques in {supported_sampling_techniques}.")
|
||||
self.N = 0
|
||||
|
||||
if isinstance(datasets[0], IterableDataset):
|
||||
self.kind = 'iterable'
|
||||
else:
|
||||
self.kind = 'map'
|
||||
|
||||
for idx, dataset in enumerate(datasets):
|
||||
isiterable = isinstance(dataset, IterableDataset)
|
||||
if (isiterable and not self.kind == 'iterable') or (not isiterable and self.kind == 'iterable'):
|
||||
raise ValueError(
|
||||
"All datasets in ConcatTranslationDataset must be of the same kind (Iterable or Map)."
|
||||
)
|
||||
|
||||
if self.kind == 'map':
|
||||
self.N += len(dataset) // world_size
|
||||
else:
|
||||
self.N += len(dataset)
|
||||
|
||||
def get_iterable(self, dataset):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
return dataset.__iter__()
|
||||
else:
|
||||
indices = np.arange(len(dataset))
|
||||
if self.shuffle:
|
||||
np.random.shuffle(indices)
|
||||
return iter(indices)
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = pt_data.get_worker_info()
|
||||
if worker_info is None:
|
||||
max_elements = self.N
|
||||
wid = 0
|
||||
wnum = 1
|
||||
else:
|
||||
wid = worker_info.id
|
||||
wnum = worker_info.num_workers
|
||||
max_elements = len(range(wid, self.N, wnum))
|
||||
|
||||
if self.kind == 'map':
|
||||
for idx in range(len(self.datasets)):
|
||||
start_idx = (len(self.datasets[idx]) // self.world_size) * self.global_rank
|
||||
end_idx = start_idx + (len(self.datasets[idx]) // self.world_size)
|
||||
if self.global_rank == self.world_size - 1:
|
||||
end_idx = len(self.datasets[idx])
|
||||
indices = range(start_idx + wid, end_idx, wnum)
|
||||
self.datasets[idx] = pt_data.Subset(self.datasets[idx], indices)
|
||||
|
||||
for idx, dataset in enumerate(self.datasets):
|
||||
iterable = self.get_iterable(dataset)
|
||||
self.iterables[idx] = iterable
|
||||
|
||||
n = 0
|
||||
ind_gen = self.index_generator(self.datasets, **self.sampling_kwargs)
|
||||
while n < max_elements:
|
||||
n += 1
|
||||
try:
|
||||
ind = next(ind_gen)
|
||||
except StopIteration:
|
||||
return
|
||||
try:
|
||||
val = next(self.iterables[ind])
|
||||
if self.kind == 'map':
|
||||
val = self.datasets[ind][val]
|
||||
yield val
|
||||
except StopIteration:
|
||||
self.iterables[ind] = self.get_iterable(self.datasets[ind])
|
||||
n -= 1
|
||||
|
||||
def __len__(self):
|
||||
return self.N
|
||||
|
||||
@staticmethod
|
||||
def temperature_generator(datasets, **kwargs):
|
||||
temp = kwargs.get('temperature')
|
||||
if not temp:
|
||||
raise ValueError("Temperature generator expects a 'temperature' keyowrd argument.")
|
||||
|
||||
lengths = []
|
||||
num = len(datasets)
|
||||
for dataset in datasets:
|
||||
lengths.append(len(dataset))
|
||||
|
||||
p = np.array(lengths) / np.sum(lengths)
|
||||
p = np.power(p, 1 / temp)
|
||||
p = p / np.sum(p)
|
||||
|
||||
while True:
|
||||
ind = np.random.choice(np.arange(num), p=p)
|
||||
yield ind
|
||||
|
||||
@staticmethod
|
||||
def round_robin_generator(datasets, **kwargs):
|
||||
num = len(datasets)
|
||||
while True:
|
||||
for i in range(num):
|
||||
yield i
|
||||
|
||||
@staticmethod
|
||||
def random_generator(datasets, **kwargs):
|
||||
p = kwargs.get('p')
|
||||
if not p:
|
||||
raise ValueError("Random generator expects a 'p' keyowrd argument for sampling probabilities.")
|
||||
|
||||
num = len(datasets)
|
||||
if len(p) != num:
|
||||
raise ValueError("Length of probabilities list must be equal to the number of datasets.")
|
||||
|
||||
while True:
|
||||
ind = np.random.choice(np.arange(num), p=p)
|
||||
yield ind
|
||||
|
|
|
@ -28,13 +28,14 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from sacrebleu import corpus_bleu
|
||||
|
||||
from nemo.collections.common.data import ConcatDataset
|
||||
from nemo.collections.common.losses import NLLLoss, SmoothedCrossEntropyLoss
|
||||
from nemo.collections.common.metrics import GlobalAverageLossMetric
|
||||
from nemo.collections.common.parts import transformer_weights_init
|
||||
from nemo.collections.common.tokenizers.chinese_tokenizers import ChineseProcessor
|
||||
from nemo.collections.common.tokenizers.en_ja_tokenizers import EnJaProcessor
|
||||
from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
|
||||
from nemo.collections.nlp.data import ConcatTranslationDataset, TarredTranslationDataset, TranslationDataset
|
||||
from nemo.collections.nlp.data import TarredTranslationDataset, TranslationDataset
|
||||
from nemo.collections.nlp.models.enc_dec_nlp_model import EncDecNLPModel
|
||||
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTEncDecModelConfig
|
||||
from nemo.collections.nlp.modules.common import TokenClassifier
|
||||
|
@ -128,7 +129,12 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
model_name = encoder_cfg_dict.pop('model_name', None)
|
||||
pretrained = encoder_cfg_dict.pop('pretrained', False)
|
||||
self.encoder = get_transformer(
|
||||
library=library, model_name=model_name, pretrained=pretrained, config_dict=encoder_cfg_dict, encoder=True,
|
||||
library=library,
|
||||
model_name=model_name,
|
||||
pretrained=pretrained,
|
||||
config_dict=encoder_cfg_dict,
|
||||
encoder=True,
|
||||
pre_ln_final_layer_norm=encoder_cfg_dict.get('pre_ln_final_layer_norm', False),
|
||||
)
|
||||
|
||||
# decoder from NeMo, Megatron-LM, or HuggingFace
|
||||
|
@ -139,7 +145,12 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
pretrained = decoder_cfg_dict.pop('pretrained', False)
|
||||
decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
|
||||
self.decoder = get_transformer(
|
||||
library=library, model_name=model_name, pretrained=pretrained, config_dict=decoder_cfg_dict, encoder=False,
|
||||
library=library,
|
||||
model_name=model_name,
|
||||
pretrained=pretrained,
|
||||
config_dict=decoder_cfg_dict,
|
||||
encoder=False,
|
||||
pre_ln_final_layer_norm=decoder_cfg_dict.get('pre_ln_final_layer_norm', False),
|
||||
)
|
||||
|
||||
self.log_softmax = TokenClassifier(
|
||||
|
@ -274,6 +285,9 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
# if user specifies one validation dataloader, then PTL reverts to giving a list of dictionary instead of a list of list of dictionary
|
||||
if isinstance(outputs[0], dict):
|
||||
outputs = [outputs]
|
||||
|
||||
loss_list = []
|
||||
sb_score_list = []
|
||||
for dataloader_idx, output in enumerate(outputs):
|
||||
if dataloader_idx == 0:
|
||||
eval_loss = getattr(self, f'{mode}_loss').compute()
|
||||
|
@ -328,6 +342,8 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
else:
|
||||
sb_score = 0.0
|
||||
|
||||
loss_list.append(eval_loss.cpu().numpy())
|
||||
sb_score_list.append(sb_score)
|
||||
if dataloader_idx == 0:
|
||||
self.log(f"{mode}_loss", eval_loss, sync_dist=True)
|
||||
self.log(f"{mode}_sacreBLEU", sb_score, sync_dist=True)
|
||||
|
@ -337,6 +353,10 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
self.log(f"{mode}_sacreBLEU_dl_index_{dataloader_idx}", sb_score, sync_dist=True)
|
||||
getattr(self, f'{mode}_loss_{dataloader_idx}').reset()
|
||||
|
||||
if len(loss_list) > 1:
|
||||
self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
|
||||
self.log(f"{mode}_sacreBLEU_avg", np.mean(sb_score_list), sync_dist=True)
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
"""
|
||||
Called at the end of validation to aggregate outputs.
|
||||
|
@ -470,7 +490,7 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
datasets.append(dataset)
|
||||
|
||||
if self.multilingual:
|
||||
dataset = ConcatTranslationDataset(
|
||||
dataset = ConcatDataset(
|
||||
datasets=datasets,
|
||||
sampling_technique=cfg.get('concat_sampling_technique'),
|
||||
sampling_temperature=cfg.get('concat_sampling_temperature'),
|
||||
|
@ -522,7 +542,7 @@ class MTEncDecModel(EncDecNLPModel):
|
|||
datasets.append(dataset)
|
||||
|
||||
if self.multilingual:
|
||||
dataset = ConcatTranslationDataset(
|
||||
dataset = ConcatDataset(
|
||||
datasets=datasets,
|
||||
shuffle=cfg.get('shuffle'),
|
||||
sampling_technique=cfg.get('concat_sampling_technique'),
|
||||
|
|
|
@ -331,7 +331,7 @@ class NLPModel(ModelPT, Exportable):
|
|||
restore_path: str,
|
||||
override_config_path: Optional[Union[OmegaConf, str]] = None,
|
||||
map_location: Optional[torch.device] = None,
|
||||
strict: bool = False,
|
||||
strict: bool = True,
|
||||
return_config: bool = False,
|
||||
trainer: Trainer = None,
|
||||
):
|
||||
|
@ -344,7 +344,7 @@ class NLPModel(ModelPT, Exportable):
|
|||
config file or an OmegaConf / DictConfig object representing the model config.
|
||||
map_location: Optional torch.device() to map the instantiated model to a device.
|
||||
By default (None), it will select a GPU if available, falling back to CPU otherwise.
|
||||
strict: Passed to load_state_dict.
|
||||
strict: Passed to load_state_dict. Set to True by default.
|
||||
return_config: If set to true, will return just the underlying config of the restored
|
||||
model as an OmegaConf DictConfig object without instantiating the model.
|
||||
trainer: PyTorch Lightning trainer. Must be passed in order to use model parallel .nemo
|
||||
|
@ -422,7 +422,7 @@ class NLPModel(ModelPT, Exportable):
|
|||
restored_model = cls._default_restore_from(
|
||||
restore_path, override_config_path, map_location, strict, return_config
|
||||
)
|
||||
restored_model._trainer = trainer
|
||||
restored_model.set_trainer(trainer)
|
||||
return restored_model
|
||||
else:
|
||||
return super().restore_from(restore_path, override_config_path, map_location, strict, return_config)
|
||||
|
|
|
@ -124,6 +124,7 @@ def get_transformer(
|
|||
config_dict: Optional[dict] = None,
|
||||
checkpoint_file: Optional[str] = None,
|
||||
encoder: bool = True,
|
||||
pre_ln_final_layer_norm=True,
|
||||
) -> Union[EncoderModule, DecoderModule]:
|
||||
"""Gets Transformer based model to be used as an Encoder or Decoder in NeMo NLP.
|
||||
First choose the library to get the transformer from. This can be huggingface,
|
||||
|
@ -159,7 +160,11 @@ def get_transformer(
|
|||
if isinstance(config_dict, NeMoTransformerConfig):
|
||||
config_dict = asdict(config_dict)
|
||||
model = get_nemo_transformer(
|
||||
model_name=model_name, pretrained=pretrained, config_dict=config_dict, encoder=encoder,
|
||||
model_name=model_name,
|
||||
pretrained=pretrained,
|
||||
config_dict=config_dict,
|
||||
encoder=encoder,
|
||||
pre_ln_final_layer_norm=pre_ln_final_layer_norm,
|
||||
)
|
||||
|
||||
if checkpoint_file is not None:
|
||||
|
|
|
@ -134,6 +134,11 @@ class TransformerDecoder(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
|
||||
if pre_ln and pre_ln_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm(hidden_size, eps=1e-5)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
layer = TransformerDecoderBlock(
|
||||
hidden_size,
|
||||
inner_size,
|
||||
|
|
|
@ -118,6 +118,11 @@ class TransformerEncoder(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
|
||||
if pre_ln and pre_ln_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm(hidden_size, eps=1e-5)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
layer = TransformerEncoderBlock(
|
||||
hidden_size,
|
||||
inner_size,
|
||||
|
|
|
@ -27,6 +27,7 @@ def get_nemo_transformer(
|
|||
pretrained: bool = False,
|
||||
config_dict: Optional[Union[dict, DictConfig]] = None,
|
||||
encoder: bool = True,
|
||||
pre_ln_final_layer_norm: bool = True,
|
||||
) -> Union[TransformerEncoderNM, TransformerDecoderNM]:
|
||||
"""Returns NeMo transformer.
|
||||
The following configurations are mandatory:
|
||||
|
@ -78,7 +79,7 @@ def get_nemo_transformer(
|
|||
hidden_act=cfg.get('hidden_act', 'relu'),
|
||||
mask_future=cfg.get('mask_future', False),
|
||||
pre_ln=cfg.get('pre_ln', False),
|
||||
pre_ln_final_layer_norm=cfg.get('pre_ln_final_layer_norm', True),
|
||||
pre_ln_final_layer_norm=pre_ln_final_layer_norm,
|
||||
num_token_types=cfg.get('num_token_types', 2),
|
||||
)
|
||||
else:
|
||||
|
@ -96,7 +97,7 @@ def get_nemo_transformer(
|
|||
attn_layer_dropout=cfg.get('attn_layer_dropout', 0.0),
|
||||
hidden_act=cfg.get('hidden_act', 'relu'),
|
||||
pre_ln=cfg.get('pre_ln', False),
|
||||
pre_ln_final_layer_norm=cfg.get('pre_ln_final_layer_norm', True),
|
||||
pre_ln_final_layer_norm=pre_ln_final_layer_norm,
|
||||
num_token_types=cfg.get('num_token_types', 2),
|
||||
)
|
||||
|
||||
|
|
|
@ -55,9 +55,9 @@ import torch
|
|||
from torch.nn.utils.rnn import pad_sequence
|
||||
from tqdm import tqdm
|
||||
|
||||
from nemo.collections.asr.parts import collections, parsers
|
||||
from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
|
||||
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
|
||||
from nemo.collections.common.parts.preprocessing import collections, parsers
|
||||
from nemo.core.classes import Dataset
|
||||
from nemo.core.neural_types.elements import *
|
||||
from nemo.core.neural_types.neural_type import NeuralType
|
||||
|
|
|
@ -21,7 +21,7 @@ from omegaconf import MISSING, DictConfig, OmegaConf
|
|||
from pytorch_lightning import Trainer
|
||||
|
||||
from nemo.collections.asr.data.audio_to_text import FastPitchDataset
|
||||
from nemo.collections.asr.parts import parsers
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.collections.tts.losses.fastpitchloss import FastPitchLoss
|
||||
from nemo.collections.tts.models.base import SpectrogramGenerator
|
||||
from nemo.collections.tts.modules.fastpitch import FastPitchModule
|
||||
|
@ -223,12 +223,12 @@ class FastPitchModel(SpectrogramGenerator):
|
|||
List of available pre-trained models.
|
||||
"""
|
||||
list_of_models = []
|
||||
# model = PretrainedModelInfo(
|
||||
# pretrained_model_name="",
|
||||
# location="",
|
||||
# description="",
|
||||
# class_=cls,
|
||||
# )
|
||||
# list_of_models.append(model)
|
||||
model = PretrainedModelInfo(
|
||||
pretrained_model_name="tts_en_fastpitch",
|
||||
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.0.0/files/tts_en_fastpitch.nemo",
|
||||
description="This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
|
||||
class_=cls,
|
||||
)
|
||||
list_of_models.append(model)
|
||||
|
||||
return list_of_models
|
||||
|
|
|
@ -24,7 +24,7 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
|
||||
from nemo.collections.asr.data.audio_to_text import FastPitchDataset
|
||||
from nemo.collections.asr.parts import parsers
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
|
||||
from nemo.collections.tts.losses.fastpitchloss import BaseFastPitchLoss
|
||||
from nemo.collections.tts.losses.fastspeech2loss import L1MelLoss
|
||||
|
@ -32,7 +32,7 @@ from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, Featur
|
|||
from nemo.collections.tts.models.base import TextToWaveform
|
||||
from nemo.collections.tts.modules.fastpitch import regulate_len
|
||||
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
|
||||
from nemo.core.classes.common import typecheck
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.core.neural_types.elements import (
|
||||
MelSpectrogramType,
|
||||
RegressionValuesType,
|
||||
|
@ -202,7 +202,7 @@ class FastPitchHifiGanE2EModel(TextToWaveform):
|
|||
"splice": NeuralType(optional=True),
|
||||
},
|
||||
output_types={
|
||||
"audio": NeuralType(('B', 'T'), MelSpectrogramType()),
|
||||
"audio": NeuralType(('B', 'S', 'T'), MelSpectrogramType()),
|
||||
"splices": NeuralType(),
|
||||
"log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
|
||||
"pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
|
||||
|
@ -248,7 +248,7 @@ class FastPitchHifiGanE2EModel(TextToWaveform):
|
|||
splices.append(start)
|
||||
gen_in = torch.stack(output)
|
||||
|
||||
output = self.generator(gen_in.transpose(1, 2))
|
||||
output = self.generator(x=gen_in.transpose(1, 2))
|
||||
|
||||
return output, splices, log_durs_predicted, pitch_predicted
|
||||
|
||||
|
@ -410,13 +410,13 @@ class FastPitchHifiGanE2EModel(TextToWaveform):
|
|||
List of available pre-trained models.
|
||||
"""
|
||||
list_of_models = []
|
||||
# model = PretrainedModelInfo(
|
||||
# pretrained_model_name="",
|
||||
# location="",
|
||||
# description="",
|
||||
# class_=cls,
|
||||
# )
|
||||
# list_of_models.append(model)
|
||||
model = PretrainedModelInfo(
|
||||
pretrained_model_name="tts_en_e2e_fastpitchhifigan",
|
||||
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_e2e_fastpitchhifigan/versions/1.0.0/files/tts_en_e2e_fastpitchhifigan.nemo",
|
||||
description="This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
|
||||
class_=cls,
|
||||
)
|
||||
list_of_models.append(model)
|
||||
|
||||
return list_of_models
|
||||
|
||||
|
@ -427,8 +427,8 @@ class FastPitchHifiGanE2EModel(TextToWaveform):
|
|||
"""
|
||||
self.eval()
|
||||
audio, _, log_dur_pred, _ = self(text=tokens, splice=False)
|
||||
audio = audio.squeeze()
|
||||
durations = torch.sum(torch.clamp(torch.exp(log_dur_pred) - 1, 0, self.max_token_duration), 1)
|
||||
audio = audio.squeeze(1)
|
||||
durations = torch.sum(torch.clamp(torch.exp(log_dur_pred) - 1, 0, self.max_token_duration), 1).to(torch.int)
|
||||
audio_list = []
|
||||
for i, sample in enumerate(audio):
|
||||
audio_list.append(sample[: durations[i] * self.hop_size])
|
||||
|
|
|
@ -22,7 +22,7 @@ from hydra.utils import instantiate
|
|||
from omegaconf import MISSING, DictConfig, OmegaConf, open_dict
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
|
||||
from nemo.collections.asr.parts import parsers
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
|
||||
from nemo.collections.tts.losses.fastspeech2loss import DurationLoss, L2MelLoss
|
||||
from nemo.collections.tts.models.base import SpectrogramGenerator
|
||||
|
|
|
@ -22,13 +22,13 @@ from hydra.utils import instantiate
|
|||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
|
||||
from nemo.collections.asr.parts import parsers
|
||||
from nemo.collections.common.parts.preprocessing import parsers
|
||||
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
|
||||
from nemo.collections.tts.losses.fastspeech2loss import DurationLoss, L1MelLoss
|
||||
from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss
|
||||
from nemo.collections.tts.models.base import TextToWaveform
|
||||
from nemo.collections.tts.modules.hifigan_modules import MultiPeriodDiscriminator, MultiScaleDiscriminator
|
||||
from nemo.core.classes.common import typecheck
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.core.neural_types.elements import (
|
||||
LengthsType,
|
||||
MaskType,
|
||||
|
@ -85,7 +85,14 @@ class FastSpeech2HifiGanE2EModel(TextToWaveform):
|
|||
|
||||
# Parser and mappings are used for inference only.
|
||||
self.parser = parsers.make_parser(name='en')
|
||||
with open(cfg.mappings_filepath, 'r') as f:
|
||||
if 'mappings_filepath' in cfg:
|
||||
mappings_filepath = cfg.get('mappings_filepath')
|
||||
else:
|
||||
logging.error(
|
||||
"ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
|
||||
)
|
||||
mappings_filepath = self.register_artifact('mappings_filepath', mappings_filepath)
|
||||
with open(mappings_filepath, 'r') as f:
|
||||
mappings = json.load(f)
|
||||
self.word2phones = mappings['word2phones']
|
||||
self.phone2idx = mappings['phone2idx']
|
||||
|
@ -138,7 +145,7 @@ class FastSpeech2HifiGanE2EModel(TextToWaveform):
|
|||
"energies": NeuralType(('B', 'T'), RegressionValuesType(), optional=True),
|
||||
},
|
||||
output_types={
|
||||
"audio": NeuralType(('B', 'T'), MelSpectrogramType()),
|
||||
"audio": NeuralType(('B', 'S', 'T'), MelSpectrogramType()),
|
||||
"splices": NeuralType(),
|
||||
"log_dur_preds": NeuralType(('B', 'T'), TokenLogDurationType()),
|
||||
"pitch_preds": NeuralType(('B', 'T'), RegressionValuesType()),
|
||||
|
@ -170,7 +177,7 @@ class FastSpeech2HifiGanE2EModel(TextToWaveform):
|
|||
splices.append(start)
|
||||
gen_in = torch.stack(output)
|
||||
|
||||
output = self.generator(gen_in.transpose(1, 2))
|
||||
output = self.generator(x=gen_in.transpose(1, 2))
|
||||
|
||||
return output, splices, log_dur_preds, pitch_preds, energy_preds, encoded_text_mask
|
||||
|
||||
|
@ -348,9 +355,6 @@ class FastSpeech2HifiGanE2EModel(TextToWaveform):
|
|||
def setup_training_data(self, cfg):
|
||||
self._train_dl = self.__setup_dataloader_from_config(cfg)
|
||||
|
||||
def list_available_models(self):
|
||||
pass
|
||||
|
||||
def setup_validation_data(self, cfg):
|
||||
self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="validation")
|
||||
|
||||
|
@ -399,10 +403,28 @@ class FastSpeech2HifiGanE2EModel(TextToWaveform):
|
|||
self.eval()
|
||||
token_len = torch.tensor([len(i) for i in tokens]).to(self.device)
|
||||
audio, _, log_dur_pred, _, _, _ = self(text=tokens, text_length=token_len, splice=False)
|
||||
audio = audio.squeeze()
|
||||
durations = torch.sum(torch.exp(log_dur_pred) - 1, 1)
|
||||
audio = audio.squeeze(1)
|
||||
durations = torch.sum(torch.exp(log_dur_pred) - 1, 1).to(torch.int)
|
||||
audio_list = []
|
||||
for i, sample in enumerate(audio):
|
||||
audio_list.append(sample[: durations[i] * self.hop_size])
|
||||
|
||||
return audio_list
|
||||
|
||||
@classmethod
|
||||
def list_available_models(cls) -> 'List[PretrainedModelInfo]':
|
||||
"""
|
||||
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
|
||||
Returns:
|
||||
List of available pre-trained models.
|
||||
"""
|
||||
list_of_models = []
|
||||
model = PretrainedModelInfo(
|
||||
pretrained_model_name="tts_en_e2e_fastspeech2hifigan",
|
||||
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_e2e_fastspeech2hifigan/versions/1.0.0/files/tts_en_e2e_fastspeech2hifigan.nemo",
|
||||
description="This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
|
||||
class_=cls,
|
||||
)
|
||||
list_of_models.append(model)
|
||||
|
||||
return list_of_models
|
||||
|
|
|
@ -23,7 +23,7 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
|
||||
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
|
||||
from nemo.collections.asr.parts.perturb import process_augmentations
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.collections.tts.helpers.helpers import log_audio_to_tb, plot_alignment_to_numpy, plot_spectrogram_to_numpy
|
||||
from nemo.collections.tts.losses.glow_tts_loss import GlowTTSLoss
|
||||
from nemo.collections.tts.models.base import SpectrogramGenerator
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue