[Jasper/PyT] Minor update in metrics and CLI params

This commit is contained in:
Mikolaj Blaz 2021-06-30 17:38:22 +00:00 committed by Krzysztof Kudrynski
parent 36c2d7e8b8
commit 0d4dd6b523
6 changed files with 71 additions and 31 deletions

View file

@ -0,0 +1,20 @@
import numpy as np
class BenchmarkStats:
    """Accumulates per-epoch utterance counts, durations and losses for benchmarking."""

    def __init__(self):
        # Parallel per-epoch histories: entry i of each list describes epoch i.
        self.utts = []
        self.times = []
        self.losses = []

    def update(self, utts, times, losses):
        """Record the totals (utterances, wall time, mean loss) of one finished epoch."""
        for history, value in ((self.utts, utts),
                               (self.times, times),
                               (self.losses, losses)):
            history.append(value)

    def get(self, n_epochs):
        """Summarize the trailing ``n_epochs`` epochs.

        Returns a dict with the aggregate throughput (utterances per second),
        the number of epochs averaged over, and the mean loss.
        """
        recent_utts = self.utts[-n_epochs:]
        recent_times = self.times[-n_epochs:]
        return {
            'throughput': sum(recent_utts) / sum(recent_times),
            'benchmark_epochs_num': n_epochs,
            'loss': np.mean(self.losses[-n_epochs:]),
        }

View file

@ -57,10 +57,6 @@ def get_parser():
help='Relative path to evaluation dataset manifest files')
parser.add_argument('--ckpt', default=None, type=str,
help='Path to model checkpoint')
parser.add_argument('--max_duration', default=None, type=float,
help='Filter out longer inputs (in seconds)')
parser.add_argument('--pad_to_max_duration', action='store_true',
help='Pads every batch to max_duration')
parser.add_argument('--amp', '--fp16', action='store_true',
help='Use FP16 precision')
parser.add_argument('--cudnn_benchmark', action='store_true',
@ -92,6 +88,9 @@ def get_parser():
help='Evaluate with a TorchScripted model')
io.add_argument('--torchscript_export', action='store_true',
help='Export the model with torch.jit to the output_dir')
io.add_argument('--override_config', type=str, action='append',
help='Overrides a value from a config .yaml.'
' Syntax: `--override_config nested.config.key=val`.')
return parser
@ -193,15 +192,7 @@ def main():
print_once(f'Inference with {distrib.get_world_size()} GPUs')
cfg = config.load(args.model_config)
if args.max_duration is not None:
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
if args.pad_to_max_duration:
assert cfg['input_val']['audio_dataset']['max_duration'] > 0
cfg['input_val']['audio_dataset']['pad_to_max_duration'] = True
cfg['input_val']['filterbank_features']['pad_to_max_duration'] = True
config.apply_config_overrides(cfg, args)
symbols = helpers.add_ctc_blank(cfg['labels'])

View file

@ -1,5 +1,10 @@
import copy
import inspect
import typing
from ast import literal_eval
from contextlib import suppress
from numbers import Number
import yaml
from .model import JasperDecoderForCTC, JasperBlock, JasperEncoder
@ -99,12 +104,22 @@ def decoder(conf, n_classes):
return validate_and_fill(JasperDecoderForCTC, decoder_kw)
def apply_duration_flags(cfg, max_duration, pad_to_max_duration):
if max_duration is not None:
cfg['input_train']['audio_dataset']['max_duration'] = max_duration
cfg['input_train']['filterbank_features']['max_duration'] = max_duration
def apply_config_overrides(conf, args):
    """Apply every ``--override_config key=val`` CLI flag to the config dict.

    Each override has the form ``nested.config.key=val``. ``val`` is parsed
    with ``ast.literal_eval`` when it is a valid Python literal and kept as a
    raw string otherwise. No-op when no overrides were passed.
    """
    if args.override_config is None:
        return
    for override_key_val in args.override_config:
        # Split on the first '=' only, so values may themselves contain '='
        # (e.g. `--override_config foo.bar=a=b`).
        key, val = override_key_val.split('=', 1)
        with suppress(TypeError, ValueError):
            # Non-literal values (e.g. bare strings) simply stay as strings.
            val = literal_eval(val)
        apply_nested_config_override(conf, key, val)
if pad_to_max_duration:
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
cfg['input_train']['filterbank_features']['pad_to_max_duration'] = True
def apply_nested_config_override(conf, key_str, val):
    """Set ``conf[a][b]...[z] = val`` for a dotted key path ``a.b...z``.

    Creating a brand-new leaf key is allowed; overriding an existing leaf
    requires the new value's type to match (any two numeric types are
    considered compatible, so e.g. an int may replace a float).

    Raises:
        KeyError: if an intermediate key on the path does not exist.
        ValueError: if the override would change the type of an existing leaf.
    """
    *parents, leaf = key_str.split('.')
    for name in parents:
        conf = conf[name]
    # Validate with an explicit raise rather than `assert`, which is
    # silently stripped when Python runs with -O.
    if (leaf in conf
            and type(val) is not type(conf[leaf])
            and not (isinstance(val, Number) and isinstance(conf[leaf], Number))):
        raise ValueError(
            f'Type mismatch for config override {key_str}: '
            f'{type(conf[leaf]).__name__} -> {type(val).__name__}')
    conf[leaf] = val

View file

@ -55,7 +55,9 @@ ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
[ -n "$PREDICTION_FILE" ] && ARGS+=" --save_prediction $PREDICTION_FILE"
[ -n "$LOGITS_FILE" ] && ARGS+=" --logits_save_to $LOGITS_FILE"
[ "$CPU" == "true" ] && ARGS+=" --cpu"
# Duration limits are forwarded through generic config overrides; the plain
# --max_duration / --pad_to_max_duration flags no longer exist in the parser.
# Both the dataset and the feature extractor must agree on the same value.
if [ -n "$MAX_DURATION" ]; then
  ARGS+=" --override_config input_val.audio_dataset.max_duration=$MAX_DURATION"
  ARGS+=" --override_config input_val.filterbank_features.max_duration=$MAX_DURATION"
fi
if [ "$PAD_TO_MAX_DURATION" = true ]; then
  ARGS+=" --override_config input_val.audio_dataset.pad_to_max_duration=True"
  ARGS+=" --override_config input_val.filterbank_features.pad_to_max_duration=True"
fi

python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS inference.py $ARGS

View file

@ -78,7 +78,10 @@ ARGS+=" --dali_device=$DALI_DEVICE"
[ "$AMP" = true ] && ARGS+=" --amp"
[ "$RESUME" = true ] && ARGS+=" --resume"
[ "$CUDNN_BENCHMARK" = true ] && ARGS+=" --cudnn_benchmark"
# Duration limits are forwarded through generic config overrides; the plain
# --pad_to_max_duration flag no longer exists in the parser. Both the dataset
# and the feature extractor must agree on the same value.
if [ -n "$MAX_DURATION" ]; then
  ARGS+=" --override_config input_train.audio_dataset.max_duration=$MAX_DURATION"
  ARGS+=" --override_config input_train.filterbank_features.max_duration=$MAX_DURATION"
fi
if [ "$PAD_TO_MAX_DURATION" = true ]; then
  ARGS+=" --override_config input_train.audio_dataset.pad_to_max_duration=True"
  ARGS+=" --override_config input_train.filterbank_features.pad_to_max_duration=True"
fi
[ -n "$CHECKPOINT" ] && ARGS+=" --ckpt=$CHECKPOINT"
[ -n "$LOG_FILE" ] && ARGS+=" --log_file $LOG_FILE"
[ -n "$PRE_ALLOCATE" ] && ARGS+=" --pre_allocate_range $PRE_ALLOCATE"

View file

@ -38,6 +38,7 @@ from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
process_evaluation_epoch)
from common.optimizers import AdamW, lr_policy, Novograd
from common.tb_dllogger import flush_log, init_log, log
from common.utils import BenchmarkStats
from jasper import config
from jasper.model import CTCLossNM, GreedyCTCDecoder, Jasper
@ -111,16 +112,17 @@ def parse_args():
help='Paths of the training dataset manifest file')
io.add_argument('--val_manifests', type=str, required=True, nargs='+',
help='Paths of the evaluation datasets manifest files')
io.add_argument('--max_duration', type=float,
help='Discard samples longer than max_duration')
io.add_argument('--pad_to_max_duration', action='store_true', default=False,
help='Pad training sequences to max_duration')
io.add_argument('--dataset_dir', required=True, type=str,
help='Root dir of dataset')
io.add_argument('--output_dir', type=str, required=True,
help='Directory for logs and checkpoints')
io.add_argument('--log_file', type=str, default=None,
help='Path to save the training logfile.')
io.add_argument('--benchmark_epochs_num', type=int, default=1,
help='Number of epochs accounted in final average throughput.')
io.add_argument('--override_config', type=str, action='append',
help='Overrides a value from a config .yaml.'
' Syntax: `--override_config nested.config.key=val`.')
return parser.parse_args()
@ -202,7 +204,7 @@ def main():
init_log(args)
cfg = config.load(args.model_config)
config.apply_duration_flags(cfg, args.max_duration, args.pad_to_max_duration)
config.apply_config_overrides(cfg, args)
symbols = helpers.add_ctc_blank(cfg['labels'])
@ -384,11 +386,14 @@ def main():
loss.backward()
model.zero_grad()
bmark_stats = BenchmarkStats()
for epoch in range(start_epoch + 1, args.epochs + 1):
if multi_gpu and not use_dali:
train_loader.sampler.set_epoch(epoch)
epoch_utts = 0
epoch_loss = 0
accumulated_batches = 0
epoch_start_time = time.time()
@ -434,6 +439,7 @@ def main():
accumulated_batches += 1
if accumulated_batches % args.grad_accumulation_steps == 0:
epoch_loss += step_loss
optimizer.step()
apply_ema(model, ema_model, args.ema)
@ -476,8 +482,11 @@ def main():
break
epoch_time = time.time() - epoch_start_time
epoch_loss /= steps_per_epoch
log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
'took': epoch_time})
'took': epoch_time,
'loss': epoch_loss})
bmark_stats.update(epoch_utts, epoch_time, epoch_loss)
if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
@ -491,7 +500,7 @@ def main():
profiler.stop()
torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})
log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))
if epoch == args.epochs:
evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,