[Jasper/PyT] Clean up inference flags

Mikolaj Blaz 2021-08-13 19:25:20 +00:00 committed by Krzysztof Kudrynski
parent 568e246408
commit 706ef498c9
5 changed files with 16 additions and 28 deletions

View file

@@ -439,6 +439,7 @@ LOG_FILE: path to the DLLogger .json logfile. (default: '')
CUDNN_BENCHMARK: enable cudnn benchmark mode for using more optimized kernels. (default: false)
MAX_DURATION: filter out recordings shorter than MAX_DURATION seconds. (default: "")
PAD_TO_MAX_DURATION: pad all sequences with zeros to maximum length. (default: false)
PAD_LEADING: pad every batch with leading zeros to counteract conv shifts of the field of view. (default: 16)
NUM_GPUS: number of GPUs to use. Note that with > 1 GPUs WER results might be inaccurate due to the batching policy. (default: 1)
NUM_STEPS: number of batches to evaluate, loop the dataset if necessary. (default: 0)
NUM_WARMUP_STEPS: number of initial steps before measuring performance. (default: 0)
@@ -464,6 +465,7 @@ BATCH_SIZE_SEQ: batch sizes to measure on. (default: "1 2 4 8 16")
MAX_DURATION_SEQ: input durations (in seconds) to measure on (default: "2 7 16.7")
CUDNN_BENCHMARK: (default: true)
PAD_TO_MAX_DURATION: (default: true)
PAD_LEADING: (default: 0)
NUM_WARMUP_STEPS: (default: 10)
NUM_STEPS: (default: 500)
DALI_DEVICE: (default: cpu)
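
Taken together, these variables flow from the inference scripts into inference.py: PAD_LEADING becomes --pad_leading, and every feature batch gets that many zero frames prepended before it reaches the acoustic model. A minimal sketch of the effect (shapes are assumed for illustration; the real tensors come from FilterbankFeatures or DALI):

    import torch
    import torch.nn.functional as F

    pad_leading = 16                          # PAD_LEADING / --pad_leading
    feats = torch.randn(8, 64, 200)           # assumed [batch, n_filters, time] features
    feat_lens = torch.full((8,), 200)         # valid frames per utterance

    feats = F.pad(feats, (pad_leading, 0))    # prepend zero frames along the time axis
    feat_lens = feat_lens + pad_leading       # keep lengths in step with the padded tensor
    print(feats.shape)                        # torch.Size([8, 64, 216])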

View file

@@ -26,6 +26,7 @@ import dllogger
import torch
import numpy as np
import torch.distributed as distrib
import torch.nn.functional as F
from apex import amp
from apex.parallel import DistributedDataParallel
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
@@ -57,8 +58,9 @@ def get_parser():
help='Relative path to evaluation dataset manifest files')
parser.add_argument('--ckpt', default=None, type=str,
help='Path to model checkpoint')
parser.add_argument("--max_duration", default=None, type=float, help='maximum duration of sequences. if None uses attribute from model configuration file')
parser.add_argument("--pad_to_max_duration", action='store_true', help='pad to maximum duration of sequences')
parser.add_argument('--pad_leading', type=int, default=16,
help='Pads every batch with leading zeros '
'to counteract conv shifts of the field of view')
parser.add_argument('--amp', '--fp16', action='store_true',
help='Use FP16 precision')
parser.add_argument('--cudnn_benchmark', action='store_true',
@@ -210,7 +212,6 @@ def main():
print("DALI supported only with input .json files; disabling")
use_dali = False
assert not args.pad_to_max_duration
assert not (args.transcribe_wav and args.transcribe_filelist)
if args.transcribe_wav:
@@ -226,6 +227,7 @@ def main():
drop_last=(True if measure_perf else False))
_, features_kw = config.input(cfg, 'val')
assert not features_kw['pad_to_max_duration']
feat_proc = FilterbankFeatures(**features_kw)
elif use_dali:
@@ -327,6 +329,9 @@ def main():
if args.amp:
feats = feats.half()
feats = F.pad(feats, (args.pad_leading, 0))
feat_lens += args.pad_leading
if model.encoder.use_conv_masks:
log_probs, log_prob_lens = model(feats, feat_lens)
else:
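
For reference, F.pad reads its pad tuple starting from the last dimension, so (args.pad_leading, 0) prepends zeros along the time axis only, and feat_lens is grown to match so the lengths passed to the model stay consistent with the padded tensor. A quick standalone check (toy shapes, assumed [batch, n_filters, time] layout):

    import torch
    import torch.nn.functional as F

    x = torch.ones(1, 2, 5)        # assumed [batch, n_filters, time]
    y = F.pad(x, (3, 0))           # pad pair applies to the last (time) dim: 3 left, 0 right
    print(y.shape)                 # torch.Size([1, 2, 8])
    print(y[0, 0])                 # tensor([0., 0., 0., 1., 1., 1., 1., 1.])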

View file

@@ -23,6 +23,7 @@
: ${CUDNN_BENCHMARK:=false}
: ${MAX_DURATION:=""}
: ${PAD_TO_MAX_DURATION:=false}
: ${PAD_LEADING:=16}
: ${NUM_GPUS:=1}
: ${NUM_STEPS:=0}
: ${NUM_WARMUP_STEPS:=0}
@@ -46,6 +47,7 @@ ARGS+=" --seed=$SEED"
ARGS+=" --dali_device=$DALI_DEVICE"
ARGS+=" --steps $NUM_STEPS"
ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
ARGS+=" --pad_leading $PAD_LEADING"
[ "$AMP" = true ] && ARGS+=" --amp"
[ "$EMA" = true ] && ARGS+=" --ema"

View file

@@ -19,6 +19,7 @@ set -a
: ${OUTPUT_DIR:=${3:-"/results"}}
: ${CUDNN_BENCHMARK:=true}
: ${PAD_TO_MAX_DURATION:=true}
: ${PAD_LEADING:=0}
: ${NUM_WARMUP_STEPS:=10}
: ${NUM_STEPS:=500}

View file

@@ -74,14 +74,7 @@ def get_dataloader(model_args_list):
return None
cfg = config.load(args.model_config)
if args.max_duration is not None:
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
if args.pad_to_max_duration:
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
config.apply_config_overrides(cfg, args)
symbols = add_ctc_blank(cfg['labels'])
@@ -108,15 +101,7 @@ def init_feature_extractor(args):
from common.features import FilterbankFeatures
cfg = config.load(args.model_config)
if args.max_duration is not None:
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
if args.pad_to_max_duration:
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
config.apply_config_overrides(cfg, args)
_, features_kw = config.input(cfg, 'val')
feature_proc = FilterbankFeatures(**features_kw)
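
In each of the three helpers in this file (the third follows below), the duplicated max_duration / pad_to_max_duration blocks are replaced by a single config.apply_config_overrides(cfg, args) call. That function lives in the jasper config module and is not shown in this commit; judging only from the lines removed here, it consolidates roughly this logic (a sketch, not the actual implementation):

    def apply_config_overrides(cfg, args):
        # Inferred from the removed per-helper blocks; the real helper may do more.
        if args.max_duration is not None:
            cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
            cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
        if args.pad_to_max_duration:
            assert cfg['input_train']['audio_dataset']['max_duration'] > 0
            cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
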
@@ -131,14 +116,7 @@ def init_acoustic_model(args):
from jasper import config
cfg = config.load(args.model_config)
if args.max_duration is not None:
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
if args.pad_to_max_duration:
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
config.apply_config_overrides(cfg, args)
if cfg['jasper']['encoder']['use_conv_masks'] == True:
print("[Jasper module]: Warning: setting 'use_conv_masks' \