Self-supervised pre-training for speech models (#3139)
* self-supervised training Signed-off-by: sam1373 <samuelkriman@gmail.com> * test Signed-off-by: sam1373 <samuelkriman@gmail.com> * remove imports Signed-off-by: sam1373 <samuelkriman@gmail.com> * fix Signed-off-by: sam1373 <samuelkriman@gmail.com> * sort imports Signed-off-by: sam1373 <samuelkriman@gmail.com> * fix audio_to_text Signed-off-by: sam1373 <samuelkriman@gmail.com> * manifest handle no text Signed-off-by: sam1373 <samuelkriman@gmail.com> * loss init Signed-off-by: sam1373 <samuelkriman@gmail.com> * style Signed-off-by: sam1373 <samuelkriman@gmail.com> * remove tokenizer from config Signed-off-by: sam1373 <samuelkriman@gmail.com> * config changes Signed-off-by: sam1373 <samuelkriman@gmail.com> * remove hydra import Signed-off-by: sam1373 <samuelkriman@gmail.com> * always spec augment Signed-off-by: sam1373 <samuelkriman@gmail.com> * fixes Signed-off-by: sam1373 <samuelkriman@gmail.com> * copyright Signed-off-by: sam1373 <samuelkriman@gmail.com> * fix cosine sim Signed-off-by: sam1373 <samuelkriman@gmail.com> * fix cosine sim Signed-off-by: sam1373 <samuelkriman@gmail.com> * fix cosine sim Signed-off-by: sam1373 <samuelkriman@gmail.com> * changes based on comments Signed-off-by: sam1373 <samuelkriman@gmail.com> * changes based on comments Signed-off-by: sam1373 <samuelkriman@gmail.com> * configs Signed-off-by: sam1373 <samuelkriman@gmail.com> * name fix Signed-off-by: sam1373 <samuelkriman@gmail.com> * ci config changes Signed-off-by: sam1373 <samuelkriman@gmail.com> * renamed to num_negatives Signed-off-by: sam1373 <samuelkriman@gmail.com> * minor changes Signed-off-by: sam1373 <samuelkriman@gmail.com> * name changes, type annotations Signed-off-by: sam1373 <samuelkriman@gmail.com> Co-authored-by: Yang Zhang <yzhang123@users.noreply.github.com>
This commit is contained in:
parent
9ab22a40bd
commit
b7a175b7b9
13
Jenkinsfile
vendored
13
Jenkinsfile
vendored
|
@ -351,6 +351,19 @@ pipeline {
|
|||
}
|
||||
}
|
||||
|
||||
stage('L2: Speech Pre-training - CitriNet') {
|
||||
steps {
|
||||
sh 'python examples/asr/speech_pre_training.py \
|
||||
--config-path="conf/citrinet_ssl/" --config-name="citrinet_ssl_ci" \
|
||||
model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
|
||||
model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
|
||||
trainer.gpus=[1] \
|
||||
+trainer.fast_dev_run=True \
|
||||
exp_manager.exp_dir=examples/asr/speech_pre_training_results'
|
||||
sh 'rm -rf examples/asr/speech_pre_training_results'
|
||||
}
|
||||
}
|
||||
|
||||
stage('L2: Speech to Text WPE - Conformer') {
|
||||
steps {
|
||||
sh 'python examples/asr/speech_to_text_bpe.py \
|
||||
|
|
474
examples/asr/conf/citrinet/citrinet_1024.yaml
Normal file
474
examples/asr/conf/citrinet/citrinet_1024.yaml
Normal file
|
@ -0,0 +1,474 @@
|
|||
# This config contains the default values for training a Citrinet model with CTC loss and BPE-based vocabulary.
|
||||
# Default learning parameters in this config are set for effective batch size of 1k on 32 GPUs.
|
||||
# To train it with smaller batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
|
||||
# If training for a short time, you can also reduce weight decay to 0.
|
||||
|
||||
# Training Recipe
|
||||
# This model can be trained using the default settings in this config with FP32 precision.
|
||||
# When training under AMP, increase `warmup_steps` to 5000 for stable training.
|
||||
# In order to create Citrinet-C, change the model.model_defaults.filters parameter.
|
||||
# When reducing the receptive field of these models, it is advised to reduce the amount of augmentation
|
||||
# for larger models from 10x time masking to 5x or 2x time masking.
|
||||
# For further details regarding Citrinet, visit - https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#citrinet
|
||||
|
||||
name: &name "Citrinet-1024-8x-Stride"
|
||||
|
||||
model:
|
||||
sample_rate: &sample_rate 16000
|
||||
log_prediction: true # enables logging sample predictions in the output during training
|
||||
|
||||
train_ds:
|
||||
manifest_filepath: ???
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
trim_silence: false
|
||||
max_duration: 20.0
|
||||
shuffle: true
|
||||
is_tarred: false
|
||||
tarred_audio_filepaths: null
|
||||
use_start_end_token: false
|
||||
|
||||
validation_ds:
|
||||
manifest_filepath: ???
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
shuffle: false
|
||||
use_start_end_token: false
|
||||
|
||||
test_ds:
|
||||
manifest_filepath: null
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
shuffle: false
|
||||
use_start_end_token: false
|
||||
|
||||
model_defaults:
|
||||
repeat: 5
|
||||
dropout: 0.1
|
||||
separable: true
|
||||
se: true
|
||||
se_context_size: -1
|
||||
kernel_size_factor: 0.25
|
||||
filters: 1024
|
||||
|
||||
tokenizer:
|
||||
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
|
||||
type: ??? # Can be either bpe or wpe
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
|
||||
sample_rate: *sample_rate
|
||||
normalize: "per_feature"
|
||||
window_size: 0.025
|
||||
window_stride: 0.01
|
||||
window: "hann"
|
||||
features: &n_mels 80
|
||||
n_fft: 512
|
||||
frame_splicing: 1
|
||||
dither: 0.00001
|
||||
pad_to: 16
|
||||
stft_conv: false
|
||||
|
||||
spec_augment:
|
||||
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
|
||||
freq_masks: 2
|
||||
time_masks: 10
|
||||
freq_width: 27
|
||||
time_width: 0.05
|
||||
|
||||
encoder:
|
||||
_target_: nemo.collections.asr.modules.ConvASREncoder
|
||||
feat_in: *n_mels
|
||||
activation: relu
|
||||
conv_mask: true
|
||||
|
||||
jasper:
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: 1
|
||||
kernel: [5]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: 0.0
|
||||
residual: false
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [11]
|
||||
stride: [2]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [13]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [15]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [17]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [19]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [21]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [13]
|
||||
stride: [2] # *stride
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [15]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [17]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [19]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [21]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [23]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [25]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [25]
|
||||
stride: [2] # stride
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [27]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [29]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [31]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [33]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [35]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [37]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [39]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: &enc_final 1024
|
||||
repeat: 1
|
||||
kernel: [41]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: 0.0
|
||||
residual: false
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
decoder:
|
||||
_target_: nemo.collections.asr.modules.ConvASRDecoder
|
||||
feat_in: *enc_final
|
||||
num_classes: -1 # filled with vocabulary size from tokenizer at runtime
|
||||
vocabulary: [] # filled with vocabulary from tokenizer at runtime
|
||||
|
||||
|
||||
optim:
|
||||
name: novograd
|
||||
lr: 0.05
|
||||
|
||||
# optimizer arguments
|
||||
betas: [0.8, 0.25]
|
||||
weight_decay: 0.001
|
||||
|
||||
# scheduler setup
|
||||
sched:
|
||||
name: CosineAnnealing
|
||||
|
||||
# scheduler config override
|
||||
warmup_steps: 5000
|
||||
warmup_ratio: null
|
||||
min_lr: 1e-5
|
||||
last_epoch: -1
|
||||
|
||||
trainer:
|
||||
gpus: 0 # number of gpus
|
||||
max_epochs: 100
|
||||
max_steps: null # computed at runtime if not set
|
||||
num_nodes: 1
|
||||
accelerator: ddp
|
||||
accumulate_grad_batches: 1
|
||||
checkpoint_callback: false # Provided by exp_manager
|
||||
logger: false # Provided by exp_manager
|
||||
log_every_n_steps: 100 # Interval of logging.
|
||||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
|
||||
check_val_every_n_epoch: 1
|
||||
precision: 32
|
||||
sync_batchnorm: false
|
||||
benchmark: false
|
||||
|
||||
exp_manager:
|
||||
exp_dir: null
|
||||
name: *name
|
||||
create_tensorboard_logger: true
|
||||
create_checkpoint_callback: true
|
||||
checkpoint_callback_params:
|
||||
monitor: "val_wer"
|
||||
mode: "min"
|
||||
save_top_k: 3
|
||||
create_wandb_logger: false
|
||||
wandb_logger_kwargs:
|
||||
name: null
|
||||
project: null
|
||||
entity: null
|
||||
resume_if_exists: false
|
||||
resume_ignore_no_checkpoint: false
|
||||
|
||||
hydra:
|
||||
run:
|
||||
dir: .
|
||||
job_logging:
|
||||
root:
|
||||
handlers: null
|
473
examples/asr/conf/citrinet_ssl/citrinet_ssl_1024.yaml
Normal file
473
examples/asr/conf/citrinet_ssl/citrinet_ssl_1024.yaml
Normal file
|
@ -0,0 +1,473 @@
|
|||
# This config contains the default values for self-supervised pre-training of a Citrinet model with contrastive loss.
|
||||
# Default learning parameters in this config are set for effective batch size of 1k on 32 GPUs.
|
||||
# To train it with smaller batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
|
||||
# If training for a short time, you can also reduce weight decay to 0.
|
||||
|
||||
# Training Recipe
|
||||
# This model can be trained using the default settings in this config with FP32 precision.
|
||||
# When training under AMP, increase `warmup_steps` to 5000 for stable training.
|
||||
# In order to create Citrinet-C, change the model.model_defaults.filters parameter.
|
||||
# When reducing the receptive field of these models, it is advised to reduce the amount of augmentation
|
||||
# for larger models from 10x time masking to 5x or 2x time masking.
|
||||
# For further details regarding Citrinet, visit - https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#citrinet
|
||||
|
||||
name: &name "Citrinet-1024-SSL-Contrastive"
|
||||
|
||||
model:
|
||||
sample_rate: &sample_rate 16000
|
||||
|
||||
train_ds:
|
||||
manifest_filepath: ???
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
trim_silence: false
|
||||
max_duration: 35.0
|
||||
min_duration: 3.0
|
||||
shuffle: true
|
||||
is_tarred: false
|
||||
tarred_audio_filepaths: null
|
||||
use_start_end_token: false
|
||||
|
||||
validation_ds:
|
||||
manifest_filepath: ???
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
shuffle: false
|
||||
use_start_end_token: false
|
||||
max_duration: 35.0
|
||||
min_duration: 3.0
|
||||
|
||||
model_defaults:
|
||||
repeat: 5
|
||||
dropout: 0.1
|
||||
separable: true
|
||||
se: true
|
||||
se_context_size: -1
|
||||
kernel_size_factor: 0.25
|
||||
filters: 1024
|
||||
decoder_out_channels: 128
|
||||
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
|
||||
sample_rate: *sample_rate
|
||||
normalize: "per_feature"
|
||||
window_size: 0.025
|
||||
window_stride: 0.01
|
||||
window: "hann"
|
||||
features: &n_mels 80
|
||||
n_fft: 512
|
||||
frame_splicing: 1
|
||||
dither: 0.00001
|
||||
pad_to: 16
|
||||
stft_conv: false
|
||||
|
||||
spec_augment:
|
||||
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
|
||||
freq_masks: 4
|
||||
time_masks: 10
|
||||
freq_width: 27
|
||||
time_width: 0.05
|
||||
|
||||
encoder:
|
||||
_target_: nemo.collections.asr.modules.ConvASREncoder
|
||||
feat_in: *n_mels
|
||||
activation: relu
|
||||
conv_mask: true
|
||||
|
||||
jasper:
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: 1
|
||||
kernel: [5]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: 0.0
|
||||
residual: false
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [11]
|
||||
stride: [2]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [13]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [15]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [17]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [19]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [21]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [13]
|
||||
stride: [2] # *stride
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [15]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [17]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [19]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [21]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [23]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [25]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [25]
|
||||
stride: [2] # stride
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [27]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [29]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [31]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [33]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [35]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [37]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [39]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: &enc_final 1024
|
||||
repeat: 1
|
||||
kernel: [41]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: 0.0
|
||||
residual: false
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
decoder:
|
||||
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
|
||||
feat_in: *enc_final
|
||||
feat_hidden: 128
|
||||
feat_out: ${model.model_defaults.decoder_out_channels}
|
||||
stride_layers: 1
|
||||
|
||||
|
||||
loss:
|
||||
_target_: nemo.collections.asr.losses.ContrastiveLoss
|
||||
in_dim: *n_mels
|
||||
proj_dim: ${model.model_defaults.decoder_out_channels}
|
||||
combine_time_steps: 4
|
||||
codebook_size: 1200
|
||||
|
||||
|
||||
optim:
|
||||
name: novograd
|
||||
lr: 0.05
|
||||
|
||||
# optimizer arguments
|
||||
betas: [0.8, 0.25]
|
||||
weight_decay: 0.001
|
||||
|
||||
# scheduler setup
|
||||
sched:
|
||||
name: CosineAnnealing
|
||||
|
||||
# scheduler config override
|
||||
warmup_steps: 5000
|
||||
warmup_ratio: null
|
||||
min_lr: 1e-5
|
||||
last_epoch: -1
|
||||
|
||||
trainer:
|
||||
gpus: 0 # number of gpus
|
||||
max_epochs: 100
|
||||
max_steps: null # computed at runtime if not set
|
||||
num_nodes: 1
|
||||
accelerator: ddp
|
||||
accumulate_grad_batches: 1
|
||||
checkpoint_callback: false # Provided by exp_manager
|
||||
logger: false # Provided by exp_manager
|
||||
log_every_n_steps: 100 # Interval of logging.
|
||||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
|
||||
check_val_every_n_epoch: 1
|
||||
precision: 32
|
||||
sync_batchnorm: false
|
||||
benchmark: false
|
||||
|
||||
exp_manager:
|
||||
exp_dir: null
|
||||
name: *name
|
||||
create_tensorboard_logger: true
|
||||
create_checkpoint_callback: true
|
||||
checkpoint_callback_params:
|
||||
monitor: "val_loss"
|
||||
mode: "min"
|
||||
save_top_k: 1
|
||||
create_wandb_logger: false
|
||||
wandb_logger_kwargs:
|
||||
name: null
|
||||
project: null
|
||||
entity: null
|
||||
resume_if_exists: false
|
||||
resume_ignore_no_checkpoint: false
|
||||
|
||||
hydra:
|
||||
run:
|
||||
dir: .
|
||||
job_logging:
|
||||
root:
|
||||
handlers: null
|
462
examples/asr/conf/citrinet_ssl/citrinet_ssl_ci.yaml
Normal file
462
examples/asr/conf/citrinet_ssl/citrinet_ssl_ci.yaml
Normal file
|
@ -0,0 +1,462 @@
|
|||
# This config is used for the CI test of self-supervised learning for CitriNet
|
||||
|
||||
name: &name "Citrinet-SSL-CI"
|
||||
|
||||
model:
|
||||
sample_rate: &sample_rate 16000
|
||||
|
||||
train_ds:
|
||||
manifest_filepath: ???
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
trim_silence: false
|
||||
max_duration: 35.0
|
||||
min_duration: 3.0
|
||||
shuffle: true
|
||||
is_tarred: false
|
||||
tarred_audio_filepaths: null
|
||||
use_start_end_token: false
|
||||
|
||||
validation_ds:
|
||||
manifest_filepath: ???
|
||||
sample_rate: 16000
|
||||
batch_size: 32
|
||||
shuffle: false
|
||||
use_start_end_token: false
|
||||
max_duration: 35.0
|
||||
min_duration: 3.0
|
||||
|
||||
model_defaults:
|
||||
repeat: 1
|
||||
dropout: 0.1
|
||||
separable: true
|
||||
se: true
|
||||
se_context_size: -1
|
||||
kernel_size_factor: 0.25
|
||||
filters: 128
|
||||
decoder_out_channels: 128
|
||||
|
||||
|
||||
preprocessor:
|
||||
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
|
||||
sample_rate: *sample_rate
|
||||
normalize: "per_feature"
|
||||
window_size: 0.025
|
||||
window_stride: 0.01
|
||||
window: "hann"
|
||||
features: &n_mels 80
|
||||
n_fft: 512
|
||||
frame_splicing: 1
|
||||
dither: 0.00001
|
||||
pad_to: 16
|
||||
stft_conv: false
|
||||
|
||||
spec_augment:
|
||||
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
|
||||
freq_masks: 4
|
||||
time_masks: 10
|
||||
freq_width: 27
|
||||
time_width: 0.05
|
||||
|
||||
encoder:
|
||||
_target_: nemo.collections.asr.modules.ConvASREncoder
|
||||
feat_in: *n_mels
|
||||
activation: relu
|
||||
conv_mask: true
|
||||
|
||||
jasper:
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: 1
|
||||
kernel: [5]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: 0.0
|
||||
residual: false
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [11]
|
||||
stride: [2]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [13]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [15]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [17]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [19]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [21]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [13]
|
||||
stride: [2] # *stride
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [15]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [17]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [19]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [21]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [23]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [25]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [25]
|
||||
stride: [2] # stride
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
stride_last: true
|
||||
residual_mode: "stride_add"
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [27]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [29]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [31]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [33]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [35]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [37]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: ${model.model_defaults.filters}
|
||||
repeat: ${model.model_defaults.repeat}
|
||||
kernel: [39]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: ${model.model_defaults.dropout}
|
||||
residual: true
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
|
||||
- filters: &enc_final 256
|
||||
repeat: 1
|
||||
kernel: [41]
|
||||
stride: [1]
|
||||
dilation: [1]
|
||||
dropout: 0.0
|
||||
residual: false
|
||||
separable: ${model.model_defaults.separable}
|
||||
se: ${model.model_defaults.se}
|
||||
se_context_size: ${model.model_defaults.se_context_size}
|
||||
kernel_size_factor: ${model.model_defaults.kernel_size_factor}
|
||||
|
||||
|
||||
decoder:
|
||||
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
|
||||
feat_in: *enc_final
|
||||
feat_hidden: 128
|
||||
feat_out: ${model.model_defaults.decoder_out_channels}
|
||||
stride_layers: 1
|
||||
|
||||
|
||||
loss:
|
||||
_target_: nemo.collections.asr.losses.ContrastiveLoss
|
||||
in_dim: *n_mels
|
||||
proj_dim: ${model.model_defaults.decoder_out_channels}
|
||||
combine_time_steps: 4
|
||||
codebook_size: 300
|
||||
|
||||
|
||||
optim:
|
||||
name: novograd
|
||||
lr: 0.05
|
||||
|
||||
# optimizer arguments
|
||||
betas: [0.8, 0.25]
|
||||
weight_decay: 0.001
|
||||
|
||||
# scheduler setup
|
||||
sched:
|
||||
name: CosineAnnealing
|
||||
|
||||
# scheduler config override
|
||||
warmup_steps: 5000
|
||||
warmup_ratio: null
|
||||
min_lr: 1e-5
|
||||
last_epoch: -1
|
||||
|
||||
trainer:
|
||||
gpus: 0 # number of gpus
|
||||
max_epochs: 100
|
||||
max_steps: null # computed at runtime if not set
|
||||
num_nodes: 1
|
||||
accelerator: ddp
|
||||
accumulate_grad_batches: 1
|
||||
checkpoint_callback: false # Provided by exp_manager
|
||||
logger: false # Provided by exp_manager
|
||||
log_every_n_steps: 100 # Interval of logging.
|
||||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
|
||||
check_val_every_n_epoch: 1
|
||||
precision: 32
|
||||
sync_batchnorm: false
|
||||
benchmark: false
|
||||
|
||||
exp_manager:
|
||||
exp_dir: null
|
||||
name: *name
|
||||
create_tensorboard_logger: true
|
||||
create_checkpoint_callback: true
|
||||
checkpoint_callback_params:
|
||||
monitor: "val_loss"
|
||||
mode: "min"
|
||||
save_top_k: 1
|
||||
create_wandb_logger: false
|
||||
wandb_logger_kwargs:
|
||||
name: null
|
||||
project: null
|
||||
entity: null
|
||||
resume_if_exists: false
|
||||
resume_ignore_no_checkpoint: false
|
||||
|
||||
hydra:
|
||||
run:
|
||||
dir: .
|
||||
job_logging:
|
||||
root:
|
||||
handlers: null
|
66
examples/asr/speech_pre_training.py
Normal file
66
examples/asr/speech_pre_training.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel
|
||||
from nemo.core.config import hydra_runner
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.exp_manager import exp_manager
|
||||
|
||||
"""
|
||||
# Example of unsupervised pre-training of a model
|
||||
```sh
|
||||
python speech_pre_training.py \
|
||||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
|
||||
model.train_ds.manifest_filepath=<path to train manifest> \
|
||||
model.validation_ds.manifest_filepath=<path to val/test manifest> \
|
||||
trainer.gpus=-1 \
|
||||
trainer.accelerator="ddp" \
|
||||
trainer.max_epochs=100 \
|
||||
model.optim.name="adamw" \
|
||||
model.optim.lr=0.001 \
|
||||
model.optim.betas=[0.9,0.999] \
|
||||
model.optim.weight_decay=0.0001 \
|
||||
model.optim.sched.warmup_steps=2000 \
|
||||
exp_manager.create_wandb_logger=True \
|
||||
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
|
||||
exp_manager.wandb_logger_kwargs.project="<Name of project>"
|
||||
```
|
||||
|
||||
For documentation on fine-tuning, please visit -
|
||||
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
|
||||
When doing supervised fine-tuning from unsupervised pre-trained encoder, set flag init_strict to False
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@hydra_runner(config_path="conf/citrinet_ssl/", config_name="citrinet_ssl_1024")
def main(cfg):
    """Entry point for self-supervised speech pre-training.

    Builds a PyTorch Lightning trainer and a SpeechEncDecSelfSupervisedModel
    from the Hydra config, optionally initializes model weights from another
    checkpoint (when requested via the config), and launches training.
    """
    logging.info(f"Hydra config: {OmegaConf.to_yaml(cfg)}")

    # Trainer must exist before exp_manager so loggers/callbacks can attach to it.
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    ssl_model = SpeechEncDecSelfSupervisedModel(cfg=cfg.model, trainer=trainer)
    # Initialize the weights of the model from another model, if provided via config
    ssl_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(ssl_model)


if __name__ == "__main__":
    main()
|
|
@ -1082,7 +1082,7 @@ class _TarredAudioToTextDataset(IterableDataset):
|
|||
logging.info("WebDataset will not shuffle files within the tar files.")
|
||||
|
||||
self._dataset = (
|
||||
self._dataset.rename(audio='wav', key='__key__')
|
||||
self._dataset.rename(audio='wav;ogg', key='__key__')
|
||||
.to_tuple('audio', 'key')
|
||||
.pipe(self._filter)
|
||||
.map(f=self._build_sample)
|
||||
|
|
|
@ -79,9 +79,12 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None)
|
|||
Returns:
|
||||
An instance of AudioToCharDataset.
|
||||
"""
|
||||
if 'labels' not in config:
|
||||
logging.warning(f"dataset does not have explicitly defined labels")
|
||||
|
||||
dataset = audio_to_text.AudioToCharDataset(
|
||||
manifest_filepath=config['manifest_filepath'],
|
||||
labels=config['labels'],
|
||||
labels=config.get('labels', None),
|
||||
sample_rate=config['sample_rate'],
|
||||
int_values=config.get('int_values', False),
|
||||
augmentor=augmentor,
|
||||
|
@ -161,6 +164,9 @@ def get_tarred_dataset(
|
|||
if len(manifest_filepaths) != len(tarred_audio_filepaths):
|
||||
raise ValueError(f"manifest_filepaths and tarred_audio_filepaths need to have the same number of buckets.")
|
||||
|
||||
if 'labels' not in config:
|
||||
logging.warning(f"dataset does not have explicitly defined labels")
|
||||
|
||||
for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
|
||||
zip(tarred_audio_filepaths, manifest_filepaths)
|
||||
):
|
||||
|
@ -170,7 +176,7 @@ def get_tarred_dataset(
|
|||
dataset = audio_to_text.TarredAudioToCharDataset(
|
||||
audio_tar_filepaths=tarred_audio_filepath,
|
||||
manifest_filepath=manifest_filepath,
|
||||
labels=config['labels'],
|
||||
labels=config.get('labels', None),
|
||||
sample_rate=config['sample_rate'],
|
||||
int_values=config.get('int_values', False),
|
||||
augmentor=augmentor,
|
||||
|
|
|
@ -14,3 +14,4 @@
|
|||
|
||||
from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
|
||||
from nemo.collections.asr.losses.ctc import CTCLoss
|
||||
from nemo.collections.asr.losses.pt_losses.contrastive import ContrastiveLoss
|
||||
|
|
15
nemo/collections/asr/losses/pt_losses/__init__.py
Normal file
15
nemo/collections/asr/losses/pt_losses/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nemo.collections.asr.losses.pt_losses.contrastive import ContrastiveLoss
|
220
nemo/collections/asr/losses/pt_losses/contrastive.py
Normal file
220
nemo/collections/asr/losses/pt_losses/contrastive.py
Normal file
|
@ -0,0 +1,220 @@
|
|||
# ! /usr/bin/python
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from nemo.core import Loss, typecheck
|
||||
from nemo.core.neural_types import LossType, NeuralType, SpectrogramType, VoidType
|
||||
|
||||
__all__ = ["ContrastiveLoss"]
|
||||
|
||||
|
||||
class ContrastiveLoss(Loss):
    """wav2vec-2.0-style contrastive loss: identify the true (optionally
    quantized) latent of each masked spectrogram step among sampled
    distractors."""

    @property
    def input_types(self):
        """Input types definitions for Contrastive.
        """
        return {
            # Original (unmasked) spectrograms, channels-first.
            "spectrograms": NeuralType(("B", "D", "T"), SpectrogramType()),
            # Binary-ish masks marking which spectrogram entries were masked.
            "spec_masks": NeuralType(("B", "D", "T"), SpectrogramType()),
            # Model predictions, time-first.
            "decoder_outputs": NeuralType(("B", "T", "D"), VoidType()),
        }

    @property
    def output_types(self):
        """Output types definitions for Contrastive.
        loss:
            NeuralType(None)
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(
        self,
        in_dim: int,
        proj_dim: int = 128,
        combine_time_steps: int = 1,
        num_negatives: int = 100,
        quantized_targets: bool = True,
        codebook_size: int = 320,
        prob_ppl_weight: float = 0.1,
        logit_temp: float = 0.1,
        reduce: str = "sum",
        sample_from_non_masked: bool = True,
        sample_from_codebook: bool = False,
        group_loss: bool = False,
        num_groups: int = 2,
        quantizer_temp_start: float = 2,
        quantizer_temp_min: float = 0.5,
        quantizer_temp_decay: float = 0.999995,
        mask_threshold: float = 0.8,
    ):
        """
        Loss function representing the contrastive task of identifying the true latent speech representation of
        the masked spectrogram steps from a set of sampled distractors.

        Args:
            in_dim: Number of spectrogram channels.
            proj_dim: Number of channels in the model outputs.
            combine_time_steps: How many time steps should be combined into a single representation.
            num_negatives: Number of sampled negatives for each target.
            quantized_targets: Bool that determines if the targets should be quantized.
            codebook_size: Number of vectors in the codebook per group.
            prob_ppl_weight: Float multiplier on the perplexity loss for target quantization.
            logit_temp: Float temperature for normalizing logits.
            reduce: String representing the type of reduction used for cross entropy.
            sample_from_non_masked: Bool that determines if negatives should be sampled from non-masked steps of the spectrogram.
            sample_from_codebook: Bool that determines if negatives should be sampled from entire codebook.
            group_loss: Bool that determines if loss should be computed separately for each group in the quantizer codebook.
            num_groups: Number of groups in the quantizer codebook.
            quantizer_temp_start: Starting temperature in quantizer.
            quantizer_temp_min: Minimum temperature in quantizer.
            quantizer_temp_decay: Decay rate of quantizer temperature per global step.
            mask_threshold: Float threshold for determining if a time step of the spectrogram is masked based on percent of masked channels.
        """

        super().__init__()
        quantizer_temp = (quantizer_temp_start, quantizer_temp_min, quantizer_temp_decay)
        self.quantized_targets = quantized_targets
        self.num_negatives = num_negatives
        self.prob_ppl_weight = prob_ppl_weight
        if self.quantized_targets:
            # Quantizer is instantiated via config dict so it goes through the
            # standard NeMo from_config_dict machinery.
            quantizer_cfg = {
                "_target_": "nemo.collections.asr.modules.wav2vec_modules.GumbelVectorQuantizer",
                "dim": in_dim * combine_time_steps,
                "vq_dim": proj_dim,
                "num_vars": codebook_size,
                "groups": num_groups,
                "temp": quantizer_temp,
                "combine_groups": True,
                "time_first": True,
            }
            self.quantizer = ContrastiveLoss.from_config_dict(quantizer_cfg)
        self.prob_ppl_weight = prob_ppl_weight
        self.logit_temp = logit_temp
        self.reduce = reduce
        self.combine_time_steps = combine_time_steps
        self.sample_from_non_masked = sample_from_non_masked
        self.sample_from_codebook = sample_from_codebook
        self.group_loss = group_loss
        self.mask_threshold = mask_threshold

        if not self.quantized_targets:
            # Without quantization, raw (combined) spectrogram steps are
            # linearly projected to the output dimension to serve as targets.
            self.target_proj = nn.Linear(in_dim * combine_time_steps, proj_dim)

    def sample_negatives(self, y: torch.Tensor, num: int):
        # y: (steps, C) pool of candidate negatives; num: number of targets (T').
        # Returns negatives of shape (num_negatives, num, C) and the drawn indices.

        high = y.shape[0]
        with torch.no_grad():
            # NOTE(review): torch.randint's `high` bound is exclusive, so
            # `high - 1` means index high-1 (the last pool entry) is never
            # sampled, and no attempt is made to exclude the positive itself
            # (collisions are masked out later in _calculate_similarity).
            neg_idxs = torch.randint(low=0, high=high - 1, size=(self.num_negatives * num,))

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(num, self.num_negatives, y.shape[-1]).permute(1, 0, 2)  # to NxTxC
        return negs, neg_idxs

    @typecheck()
    def forward(self, spectrograms, spec_masks, decoder_outputs):
        # Move to time-first layout for per-step processing.
        spec_in = spectrograms.transpose(-2, -1)
        masks = spec_masks.transpose(-2, -1)
        targets = spec_in
        # BxTxC

        # Collapse every `combine_time_steps` adjacent steps into one target vector.
        targets = targets.reshape(targets.shape[0], targets.shape[1] // self.combine_time_steps, -1)
        masks = masks.reshape(targets.shape)

        if self.quantized_targets:
            targets, prob_ppl_loss, cur_codebook_temp = self.quantizer(targets)
        else:
            targets = self.target_proj(targets)

        # A combined step counts as masked when the fraction of its masked
        # channels exceeds mask_threshold.
        masks = masks.mean(-1) > self.mask_threshold
        out_masked_only = decoder_outputs[masks]
        targets_masked_only = targets[masks]
        # T'xC
        # number of masked time steps to predict (T')

        if self.group_loss:
            # Compare each codebook group independently: use all codebook
            # vectors of every group as negatives for every masked step.
            num_groups = self.quantizer.groups
            negatives = self.quantizer.vars.reshape(num_groups, self.quantizer.num_vars, -1)
            # GxNx(C//G)
            negatives = negatives.transpose(0, 1)
            # NxGx(C//G)
            negatives = negatives.unsqueeze(1).expand(-1, out_masked_only.shape[0], -1, -1)
            # NxT'xGx(C//G)
            negatives = negatives.reshape(negatives.shape[0], -1, negatives.shape[-1])
            # NxT'Gx(C//G)

            out_masked_only = out_masked_only.reshape(-1, out_masked_only.shape[-1] // num_groups)
            targets_masked_only = targets_masked_only.reshape(-1, targets_masked_only.shape[-1] // num_groups)
            # T'Gx(C//G)
        elif self.sample_from_codebook:
            # sample from the full codebook
            negatives = self.quantizer.sample_from_codebook(self.num_negatives, targets_masked_only.size(0))
        elif self.sample_from_non_masked:
            # sample from all steps in batch
            negatives, _ = self.sample_negatives(
                targets.reshape(targets.shape[0] * targets.shape[1], -1), targets_masked_only.size(0),  # BTxC
            )  # T'
        else:
            # only sample from masked steps
            negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0))  # T'xC # T'
        # NxT'xC

        # Calculate similarity between logits and all targets
        similarity_scores = self._calculate_similarity(out_masked_only, negatives, targets_masked_only)
        # (1+N)xT'
        # cosine similarity of outs with targets + N negatives

        # Create targets of size T
        similarity_targets = decoder_outputs.new_zeros(similarity_scores.size(1), dtype=torch.long)
        # T'
        # targets are 0, since it's the first, followed by N sampled negatives

        # Transpose similarity scores to TxF for loss
        similarity_scores = similarity_scores.transpose(0, 1)
        # T'x(1+N)

        loss = F.cross_entropy(similarity_scores, similarity_targets, reduction=self.reduce)

        sample_size = similarity_targets.numel()

        if self.prob_ppl_weight != 0 and self.quantized_targets:
            # Diversity (perplexity) penalty from the quantizer, scaled by the
            # number of predictions so it matches the summed CE magnitude.
            prob_ppl_loss = self.prob_ppl_weight * prob_ppl_loss * sample_size
            loss += prob_ppl_loss

        if not isinstance(loss, torch.Tensor):
            # Can happen when there were no masked steps to predict; return a
            # zero loss on the right device instead of a Python scalar.
            loss = torch.Tensor([0]).to(device=decoder_outputs.device)

        return loss

    def _calculate_similarity(self, logits, negatives, targets):
        neg_is_pos = (targets == negatives).all(-1)
        # NxT' - true where the negative is actually the positive
        targets = targets.unsqueeze(0)
        # 1xT'xC
        targets = torch.cat([targets, negatives], dim=0)
        # (1+N)xT'XC
        logits = torch.cosine_similarity(
            logits.float().unsqueeze(0).expand(targets.shape[0], -1, -1), targets.float(), dim=-1
        ).type_as(logits)
        # (1+N)xT'
        logits /= self.logit_temp
        if neg_is_pos.any():
            # Accidental positives among negatives get -inf so cross entropy
            # never rewards picking them.
            logits[1:][neg_is_pos] = float("-inf")
        return logits

    def set_num_updates(self, num_updates):
        # NOTE(review): self.quantizer only exists when quantized_targets=True;
        # calling this with quantized_targets=False would raise AttributeError.
        self.quantizer.set_num_updates(num_updates)
|
288
nemo/collections/asr/models/ssl_models.py
Normal file
288
nemo/collections/asr/models/ssl_models.py
Normal file
|
@ -0,0 +1,288 @@
|
|||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from math import ceil
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
from nemo.collections.asr.data import audio_to_text_dataset
|
||||
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
|
||||
from nemo.collections.asr.parts.mixins import ASRModuleMixin
|
||||
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
|
||||
from nemo.core.classes import ModelPT
|
||||
from nemo.core.classes.common import PretrainedModelInfo, typecheck
|
||||
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType, SpectrogramType, VoidType
|
||||
from nemo.utils import logging
|
||||
|
||||
__all__ = ['SpeechEncDecSelfSupervisedModel']
|
||||
|
||||
|
||||
class SpeechEncDecSelfSupervisedModel(ModelPT, ASRModuleMixin):
    """Base class for encoder-decoder models used for self-supervised encoder pre-training"""

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        # No pre-trained checkpoints are published for this model yet.
        return []

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.preprocessor)
        self.encoder = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            # Infer the decoder's input feature size from the encoder output size
            # when the config leaves it unset/empty.
            if "feat_in" not in self._cfg.decoder or (
                not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out')
            ):
                self._cfg.decoder.feat_in = self.encoder._feat_out
            if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                raise ValueError("param feat_in of the decoder's config is not set!")

        self.decoder_ssl = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.decoder)
        self.loss = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.loss)

        # Spec augmentation is a required component here: the SSL objective is computed
        # over the masked regions it produces (see forward()).
        self.spec_augmentation = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.spec_augment)

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        """
        Builds a DataLoader from a dataset config dict.

        Supports tarred (webdataset-style) and regular character datasets; returns
        ``None`` (with a warning) when the required manifest/tar paths are missing.
        """
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')

        shuffle = config['shuffle']

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            # Tarred datasets shuffle internally via shuffle_n; the DataLoader must not.
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
                return None

            dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the training data loader via a Dict-like object.

        Args:
            train_data_config: A config that contains the information regarding construction
                of an ASR Training dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in train_data_config:
            train_data_config['shuffle'] = True

        # preserve config
        self._update_dataset_config(dataset_name='train', config=train_data_config)

        self._train_dl = self._setup_dataloader_from_config(config=train_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
                )

    def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the validation data loader via a Dict-like object.

        Args:
            val_data_config: A config that contains the information regarding construction
                of an ASR Validation dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in val_data_config:
            val_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='validation', config=val_data_config)

        self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if 'is_tarred' in val_data_config and val_data_config['is_tarred']:
            # We also need to check if limit_val_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # validation batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if isinstance(self._trainer.limit_val_batches, float):
                self._trainer.limit_val_batches = int(
                    self._trainer.limit_val_batches
                    * ceil((len(self._validation_dl.dataset) / self.world_size) / val_data_config['batch_size'])
                )

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "spectrograms": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "spec_masks": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "outputs": NeuralType(('B', 'T', 'D'), VoidType()),
        }

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
    ):
        """
        Forward pass of the model.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            A tuple of 3 elements -
            1) Processed spectrograms of shape [B, D, T].
            2) Masks applied to spectrograms of shape [B, D, T].
            3) Decoder outputs of shape [B, T, D].
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Exactly one of the two signal forms must be provided (XOR);
        # `not (a ^ b)` replaces the original `(a ^ b) == False` boolean-literal comparison.
        if not (has_input_signal ^ has_processed_signal):
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal, length=input_signal_length,
            )

        # Keep an unmasked copy of the spectrograms as the reconstruction target.
        spectrograms = processed_signal.detach().clone()

        processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        # Recover the augmentation masks: masked time/freq cells were zeroed out,
        # so near-zero values mark masked positions.
        masked_spectrograms = processed_signal.detach()
        spec_masks = torch.logical_and(masked_spectrograms < 1e-5, masked_spectrograms > -1e-5).float()
        # Padding beyond each utterance's true length must not count as masked.
        for idx, proc_len in enumerate(processed_signal_length):
            spec_masks[idx, :, proc_len:] = 0.0

        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)

        outputs = self.decoder_ssl(encoder_output=encoded)

        return spectrograms, spec_masks, outputs

    # PTL-specific methods
    def training_step(self, batch, batch_nb):
        signal, signal_len, transcript, transcript_len = batch
        spectrograms, spec_masks, outputs = self.forward(input_signal=signal, input_signal_length=signal_len)

        # The loss may use the global step for update-dependent schedules.
        self.loss.set_num_updates(self.trainer.global_step)
        loss_value = self.loss(spectrograms=spectrograms, spec_masks=spec_masks, decoder_outputs=outputs)

        tensorboard_logs = {'train_loss': loss_value, 'learning_rate': self._optimizer.param_groups[0]['lr']}

        return {'loss': loss_value, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        signal, signal_len, transcript, transcript_len = batch
        spectrograms, spec_masks, outputs = self.forward(input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(spectrograms=spectrograms, spec_masks=spec_masks, decoder_outputs=outputs)

        return {
            'val_loss': loss_value,
        }

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        # Average the per-batch validation losses for this dataloader.
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': val_loss_mean}
        return {'val_loss': val_loss_mean, 'log': tensorboard_logs}
|
|
@ -28,6 +28,7 @@ try:
|
|||
from nemo.collections.asr.modules.conv_asr import (
|
||||
ConvASRDecoder,
|
||||
ConvASRDecoderClassification,
|
||||
ConvASRDecoderReconstruction,
|
||||
ConvASREncoder,
|
||||
ECAPAEncoder,
|
||||
ParallelConvASREncoder,
|
||||
|
@ -43,4 +44,5 @@ except ModuleNotFoundError:
|
|||
class ECAPAEncoder(CheckInstall): pass
|
||||
class ParallelConvASREncoder(CheckInstall): pass
|
||||
class SpeakerDecoder(CheckInstall): pass
|
||||
class ConvASRDecoderReconstruction(CheckInstall): pass
|
||||
# fmt: on
|
||||
|
|
|
@ -468,6 +468,91 @@ class ConvASRDecoder(NeuralModule, Exportable):
|
|||
return self._num_classes
|
||||
|
||||
|
||||
class ConvASRDecoderReconstruction(NeuralModule, Exportable):
    """ASR Decoder for reconstructing masked regions of spectrogram
    """

    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        return OrderedDict({"spec_recon": NeuralType(('B', 'T', 'D'), SpectrogramType())})

    def __init__(
        self,
        feat_in,
        feat_out,
        feat_hidden,
        stride_layers,
        kernel_size=11,
        init_mode="xavier_uniform",
        activation="relu",
    ):
        """
        Args:
            feat_in: number of input channels (encoder output feature size).
            feat_out: number of output features (spectrogram feature size).
            feat_hidden: hidden channel width of the decoder stack.
            stride_layers: number of stride-2 transposed-conv upsampling stages.
            kernel_size: transposed-conv kernel size; must be >= 3 and odd when
                stride_layers > 0 (required by the padding/output_padding math below).
            init_mode: weight init mode passed to init_weights.
            activation: key into jasper_activations selecting the nonlinearity.
        """
        super().__init__()

        if stride_layers > 0 and (kernel_size < 3 or kernel_size % 2 == 0):
            raise ValueError(
                "Kernel size in this decoder needs to be >= 3 and odd when using at least 1 stride layer."
            )

        activation = jasper_activations[activation]()

        self.feat_in = feat_in
        self.feat_out = feat_out
        self.feat_hidden = feat_hidden

        # 1x1 projection from encoder features into the hidden width.
        self.decoder_layers = [nn.Conv1d(self.feat_in, self.feat_hidden, kernel_size=1, bias=True)]
        for i in range(stride_layers):
            self.decoder_layers.append(activation)
            # Each stage doubles the time resolution (stride=2); padding/output_padding
            # are chosen so the output length is exactly 2x the input length.
            self.decoder_layers.append(
                nn.ConvTranspose1d(
                    self.feat_hidden,
                    self.feat_hidden,
                    kernel_size,
                    stride=2,
                    padding=(kernel_size - 3) // 2 + 1,
                    output_padding=1,
                    bias=True,
                )
            )
            self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_hidden, kernel_size=1, bias=True))
            self.decoder_layers.append(nn.BatchNorm1d(self.feat_hidden, eps=1e-3, momentum=0.1))

        self.decoder_layers.append(activation)
        # Final 1x1 projection to the spectrogram feature size.
        self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_out, kernel_size=1, bias=True))

        self.decoder_layers = nn.Sequential(*self.decoder_layers)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, encoder_output):
        # (B, D, T) -> (B, T, D) to match the SpectrogramType output layout.
        return self.decoder_layers(encoder_output).transpose(-2, -1)

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        bs = 8
        seq = 64
        # BUGFIX: __init__ stores the attribute as `feat_in` (no leading underscore);
        # the original `self._feat_in` raised AttributeError here.
        input_example = torch.randn(bs, self.feat_in, seq).to(next(self.parameters()).device)
        return tuple([input_example])

    def _prepare_for_export(self, **kwargs):
        # Masked convolutions are not export-friendly; disable their masking.
        m_count = 0
        for m in self.modules():
            if type(m).__name__ == "MaskedConv1d":
                m.use_mask = False
                m_count += 1
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
        Exportable._prepare_for_export(self, **kwargs)
||||
|
||||
|
||||
class ConvASRDecoderClassification(NeuralModule, Exportable):
|
||||
"""Simple ASR Decoder for use with classification models such as JasperNet and QuartzNet
|
||||
|
||||
|
|
|
@ -103,15 +103,11 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]:
|
|||
item['text'] = f.read().replace('\n', '')
|
||||
elif 'normalized_text' in item:
|
||||
item['text'] = item['normalized_text']
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Manifest file {manifest_file} has invalid json line structure: {line} without proper text key."
|
||||
)
|
||||
|
||||
item = dict(
|
||||
audio_file=item['audio_file'],
|
||||
duration=item['duration'],
|
||||
text=item['text'],
|
||||
text=item.get('text', ""),
|
||||
offset=item.get('offset', None),
|
||||
speaker=item.get('speaker', None),
|
||||
orig_sr=item.get('orig_sample_rate', None),
|
||||
|
|
|
@ -892,7 +892,9 @@ class ModelPT(LightningModule, Model):
|
|||
with open_dict(cfg):
|
||||
# Restore model
|
||||
model_path = cfg.pop('init_from_nemo_model')
|
||||
restored_model = self.restore_from(model_path, map_location=map_location, strict=True)
|
||||
restored_model = self.restore_from(
|
||||
model_path, map_location=map_location, strict=cfg.get("init_strict", True)
|
||||
)
|
||||
|
||||
# Restore checkpoint into current model
|
||||
self.load_state_dict(restored_model.state_dict(), strict=False)
|
||||
|
@ -918,7 +920,9 @@ class ModelPT(LightningModule, Model):
|
|||
)
|
||||
return
|
||||
|
||||
restored_model = self.from_pretrained(model_name, map_location=map_location, strict=True)
|
||||
restored_model = self.from_pretrained(
|
||||
model_name, map_location=map_location, strict=cfg.get("init_strict", True)
|
||||
)
|
||||
|
||||
# Restore checkpoint into current model
|
||||
self.load_state_dict(restored_model.state_dict(), strict=False)
|
||||
|
|
Loading…
Reference in a new issue