[Jasper/PyT] Clean up inference flags
This commit is contained in:
parent
568e246408
commit
706ef498c9
|
@ -439,6 +439,7 @@ LOG_FILE: path to the DLLogger .json logfile. (default: '')
|
|||
CUDNN_BENCHMARK: enable cudnn benchmark mode for using more optimized kernels. (default: false)
|
||||
MAX_DURATION: filter out recordings shorter than MAX_DURATION seconds. (default: "")
|
||||
PAD_TO_MAX_DURATION: pad all sequences with zeros to maximum length. (default: false)
|
||||
PAD_LEADING: pad every batch with leading zeros to counteract conv shifts of the field of view. (default: 16)
|
||||
NUM_GPUS: number of GPUs to use. Note that with > 1 GPUs WER results might be inaccurate due to the batching policy. (default: 1)
|
||||
NUM_STEPS: number of batches to evaluate, loop the dataset if necessary. (default: 0)
|
||||
NUM_WARMUP_STEPS: number of initial steps before measuring performance. (default: 0)
|
||||
|
@ -464,6 +465,7 @@ BATCH_SIZE_SEQ: batch sizes to measure on. (default: "1 2 4 8 16")
|
|||
MAX_DURATION_SEQ: input durations (in seconds) to measure on (default: "2 7 16.7")
|
||||
CUDNN_BENCHMARK: (default: true)
|
||||
PAD_TO_MAX_DURATION: (default: true)
|
||||
PAD_LEADING: (default: 0)
|
||||
NUM_WARMUP_STEPS: (default: 10)
|
||||
NUM_STEPS: (default: 500)
|
||||
DALI_DEVICE: (default: cpu)
|
||||
|
|
|
@ -26,6 +26,7 @@ import dllogger
|
|||
import torch
|
||||
import numpy as np
|
||||
import torch.distributed as distrib
|
||||
import torch.nn.functional as F
|
||||
from apex import amp
|
||||
from apex.parallel import DistributedDataParallel
|
||||
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
|
||||
|
@ -57,8 +58,9 @@ def get_parser():
|
|||
help='Relative path to evaluation dataset manifest files')
|
||||
parser.add_argument('--ckpt', default=None, type=str,
|
||||
help='Path to model checkpoint')
|
||||
parser.add_argument("--max_duration", default=None, type=float, help='maximum duration of sequences. if None uses attribute from model configuration file')
|
||||
parser.add_argument("--pad_to_max_duration", action='store_true', help='pad to maximum duration of sequences')
|
||||
parser.add_argument('--pad_leading', type=int, default=16,
|
||||
help='Pads every batch with leading zeros '
|
||||
'to counteract conv shifts of the field of view')
|
||||
parser.add_argument('--amp', '--fp16', action='store_true',
|
||||
help='Use FP16 precision')
|
||||
parser.add_argument('--cudnn_benchmark', action='store_true',
|
||||
|
@ -210,7 +212,6 @@ def main():
|
|||
print("DALI supported only with input .json files; disabling")
|
||||
use_dali = False
|
||||
|
||||
assert not args.pad_to_max_duration
|
||||
assert not (args.transcribe_wav and args.transcribe_filelist)
|
||||
|
||||
if args.transcribe_wav:
|
||||
|
@ -226,6 +227,7 @@ def main():
|
|||
drop_last=(True if measure_perf else False))
|
||||
|
||||
_, features_kw = config.input(cfg, 'val')
|
||||
assert not features_kw['pad_to_max_duration']
|
||||
feat_proc = FilterbankFeatures(**features_kw)
|
||||
|
||||
elif use_dali:
|
||||
|
@ -327,6 +329,9 @@ def main():
|
|||
if args.amp:
|
||||
feats = feats.half()
|
||||
|
||||
feats = F.pad(feats, (args.pad_leading, 0))
|
||||
feat_lens += args.pad_leading
|
||||
|
||||
if model.encoder.use_conv_masks:
|
||||
log_probs, log_prob_lens = model(feats, feat_lens)
|
||||
else:
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
: ${CUDNN_BENCHMARK:=false}
|
||||
: ${MAX_DURATION:=""}
|
||||
: ${PAD_TO_MAX_DURATION:=false}
|
||||
: ${PAD_LEADING:=16}
|
||||
: ${NUM_GPUS:=1}
|
||||
: ${NUM_STEPS:=0}
|
||||
: ${NUM_WARMUP_STEPS:=0}
|
||||
|
@ -46,6 +47,7 @@ ARGS+=" --seed=$SEED"
|
|||
ARGS+=" --dali_device=$DALI_DEVICE"
|
||||
ARGS+=" --steps $NUM_STEPS"
|
||||
ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
|
||||
ARGS+=" --pad_leading $PAD_LEADING"
|
||||
|
||||
[ "$AMP" = true ] && ARGS+=" --amp"
|
||||
[ "$EMA" = true ] && ARGS+=" --ema"
|
||||
|
|
|
@ -19,6 +19,7 @@ set -a
|
|||
: ${OUTPUT_DIR:=${3:-"/results"}}
|
||||
: ${CUDNN_BENCHMARK:=true}
|
||||
: ${PAD_TO_MAX_DURATION:=true}
|
||||
: ${PAD_LEADING:=0}
|
||||
: ${NUM_WARMUP_STEPS:=10}
|
||||
: ${NUM_STEPS:=500}
|
||||
|
||||
|
|
|
@ -74,14 +74,7 @@ def get_dataloader(model_args_list):
|
|||
return None
|
||||
|
||||
cfg = config.load(args.model_config)
|
||||
|
||||
if args.max_duration is not None:
|
||||
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
|
||||
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
|
||||
|
||||
if args.pad_to_max_duration:
|
||||
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
|
||||
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
|
||||
config.apply_config_overrides(cfg, args)
|
||||
|
||||
symbols = add_ctc_blank(cfg['labels'])
|
||||
|
||||
|
@ -108,15 +101,7 @@ def init_feature_extractor(args):
|
|||
from common.features import FilterbankFeatures
|
||||
|
||||
cfg = config.load(args.model_config)
|
||||
|
||||
if args.max_duration is not None:
|
||||
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
|
||||
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
|
||||
|
||||
if args.pad_to_max_duration:
|
||||
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
|
||||
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
|
||||
|
||||
config.apply_config_overrides(cfg, args)
|
||||
_, features_kw = config.input(cfg, 'val')
|
||||
|
||||
feature_proc = FilterbankFeatures(**features_kw)
|
||||
|
@ -131,14 +116,7 @@ def init_acoustic_model(args):
|
|||
from jasper import config
|
||||
|
||||
cfg = config.load(args.model_config)
|
||||
|
||||
if args.max_duration is not None:
|
||||
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
|
||||
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
|
||||
|
||||
if args.pad_to_max_duration:
|
||||
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
|
||||
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
|
||||
config.apply_config_overrides(cfg, args)
|
||||
|
||||
if cfg['jasper']['encoder']['use_conv_masks'] == True:
|
||||
print("[Jasper module]: Warning: setting 'use_conv_masks' \
|
||||
|
|
Loading…
Reference in a new issue