[Jasper/PyT] Minor update in metrics and CLI params
This commit is contained in:
parent
36c2d7e8b8
commit
0d4dd6b523
20
PyTorch/SpeechRecognition/Jasper/common/utils.py
Normal file
20
PyTorch/SpeechRecognition/Jasper/common/utils.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
class BenchmarkStats:
    """Tracks per-epoch statistics (utterances, time, loss) for benchmarking.

    Call `update` once per finished epoch, then `get` to average over the
    most recent epochs.
    """
    def __init__(self):
        self.utts = []    # number of utterances processed per epoch
        self.times = []   # wall-clock duration of each epoch (seconds)
        self.losses = []  # average loss of each epoch

    def update(self, utts, times, losses):
        """Record the stats of a single finished epoch."""
        self.utts.append(utts)
        self.times.append(times)
        self.losses.append(losses)

    def get(self, n_epochs):
        """Return stats averaged over the last `n_epochs` recorded epochs.

        If fewer than `n_epochs` epochs were recorded, averages over all
        available epochs and reports that (smaller) number instead.

        Args:
            n_epochs: number of trailing epochs to average over.

        Returns:
            dict with 'throughput' (utterances / second),
            'benchmark_epochs_num' (epochs actually averaged) and
            'loss' (mean loss over those epochs).

        Raises:
            ValueError: if no epochs have been recorded yet.
        """
        if not self.utts:
            raise ValueError('No epochs recorded; call update() first.')
        # clamp so we never slice/average over more data than we have
        n = min(n_epochs, len(self.utts))
        throughput = sum(self.utts[-n:]) / sum(self.times[-n:])
        return {'throughput': throughput,
                'benchmark_epochs_num': n,
                'loss': np.mean(self.losses[-n:])}
|
|
@ -57,10 +57,6 @@ def get_parser():
|
|||
help='Relative path to evaluation dataset manifest files')
|
||||
parser.add_argument('--ckpt', default=None, type=str,
|
||||
help='Path to model checkpoint')
|
||||
parser.add_argument('--max_duration', default=None, type=float,
|
||||
help='Filter out longer inputs (in seconds)')
|
||||
parser.add_argument('--pad_to_max_duration', action='store_true',
|
||||
help='Pads every batch to max_duration')
|
||||
parser.add_argument('--amp', '--fp16', action='store_true',
|
||||
help='Use FP16 precision')
|
||||
parser.add_argument('--cudnn_benchmark', action='store_true',
|
||||
|
@ -92,6 +88,9 @@ def get_parser():
|
|||
help='Evaluate with a TorchScripted model')
|
||||
io.add_argument('--torchscript_export', action='store_true',
|
||||
help='Export the model with torch.jit to the output_dir')
|
||||
io.add_argument('--override_config', type=str, action='append',
|
||||
help='Overrides a value from a config .yaml.'
|
||||
' Syntax: `--override_config nested.config.key=val`.')
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -193,15 +192,7 @@ def main():
|
|||
print_once(f'Inference with {distrib.get_world_size()} GPUs')
|
||||
|
||||
cfg = config.load(args.model_config)
|
||||
|
||||
if args.max_duration is not None:
|
||||
cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
|
||||
cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
|
||||
|
||||
if args.pad_to_max_duration:
|
||||
assert cfg['input_val']['audio_dataset']['max_duration'] > 0
|
||||
cfg['input_val']['audio_dataset']['pad_to_max_duration'] = True
|
||||
cfg['input_val']['filterbank_features']['pad_to_max_duration'] = True
|
||||
config.apply_config_overrides(cfg, args)
|
||||
|
||||
symbols = helpers.add_ctc_blank(cfg['labels'])
|
||||
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
import copy
|
||||
import inspect
|
||||
import typing
|
||||
from ast import literal_eval
|
||||
from contextlib import suppress
|
||||
from numbers import Number
|
||||
|
||||
import yaml
|
||||
|
||||
from .model import JasperDecoderForCTC, JasperBlock, JasperEncoder
|
||||
|
@ -99,12 +104,22 @@ def decoder(conf, n_classes):
|
|||
return validate_and_fill(JasperDecoderForCTC, decoder_kw)
|
||||
|
||||
|
||||
def apply_duration_flags(cfg, max_duration, pad_to_max_duration):
|
||||
if max_duration is not None:
|
||||
cfg['input_train']['audio_dataset']['max_duration'] = max_duration
|
||||
cfg['input_train']['filterbank_features']['max_duration'] = max_duration
|
||||
def apply_config_overrides(conf, args):
    """Apply `--override_config key=val` CLI flags to a loaded config dict.

    Each override has the form `nested.config.key=val`. Values are parsed
    with `ast.literal_eval` when possible (numbers, bools, lists, ...);
    values that fail to parse are kept as raw strings.

    Args:
        conf: nested config dict (mutated in place).
        args: parsed CLI namespace with an `override_config` attribute
            that is either None or a list of `key=val` strings.
    """
    if args.override_config is None:
        return
    for override_key_val in args.override_config:
        # split only on the first '=' so values may themselves contain '='
        key, val = override_key_val.split('=', 1)
        # literal_eval may raise SyntaxError on malformed input (e.g. '1+'),
        # not just TypeError/ValueError — fall back to the raw string then too
        with suppress(TypeError, ValueError, SyntaxError):
            val = literal_eval(val)
        apply_nested_config_override(conf, key, val)
|
||||
|
||||
if pad_to_max_duration:
|
||||
assert cfg['input_train']['audio_dataset']['max_duration'] > 0
|
||||
cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
|
||||
cfg['input_train']['filterbank_features']['pad_to_max_duration'] = True
|
||||
|
||||
def apply_nested_config_override(conf, key_str, val):
    """Set `conf[k1][k2]...[kn] = val` for a dot-separated `key_str`.

    The new value must either introduce a new key or match the type of
    the existing value (any `numbers.Number` may replace another).

    Args:
        conf: nested config dict (mutated in place).
        key_str: dot-separated path, e.g. 'input_val.audio_dataset.max_duration'.
        val: value to assign at that path.

    Raises:
        KeyError: if an intermediate key on the path does not exist.
        ValueError: if `val`'s type conflicts with the existing value's type.
    """
    fields = key_str.split('.')
    for f in fields[:-1]:
        conf = conf[f]
    f = fields[-1]
    # validate with an explicit raise instead of `assert`, which is
    # silently stripped when Python runs with -O
    if (f in conf
            and type(val) is not type(conf[f])
            and not (isinstance(val, Number) and isinstance(conf[f], Number))):
        raise ValueError(
            f'Type mismatch overriding {key_str}: '
            f'{type(conf[f]).__name__} -> {type(val).__name__}')
    conf[f] = val
|
||||
|
|
|
@ -55,7 +55,9 @@ ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
|
|||
[ -n "$PREDICTION_FILE" ] && ARGS+=" --save_prediction $PREDICTION_FILE"
|
||||
[ -n "$LOGITS_FILE" ] && ARGS+=" --logits_save_to $LOGITS_FILE"
|
||||
[ "$CPU" == "true" ] && ARGS+=" --cpu"
|
||||
[ -n "$MAX_DURATION" ] && ARGS+=" --max_duration $MAX_DURATION"
|
||||
[ "$PAD_TO_MAX_DURATION" = true ] && ARGS+=" --pad_to_max_duration"
|
||||
# Pass MAX_DURATION / PAD_TO_MAX_DURATION through --override_config.
# Use explicit if-blocks: chaining two `ARGS+=` assignments after `&&` via
# a backslash continuation relies on obscure "assignments without a command
# word" shell behavior and is easy to break when editing.
if [ -n "$MAX_DURATION" ]; then
  ARGS+=" --override_config input_val.audio_dataset.max_duration=$MAX_DURATION"
  ARGS+=" --override_config input_val.filterbank_features.max_duration=$MAX_DURATION"
fi
if [ "$PAD_TO_MAX_DURATION" = true ]; then
  ARGS+=" --override_config input_val.audio_dataset.pad_to_max_duration=True"
  ARGS+=" --override_config input_val.filterbank_features.pad_to_max_duration=True"
fi
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS inference.py $ARGS
|
||||
|
|
|
@ -78,7 +78,10 @@ ARGS+=" --dali_device=$DALI_DEVICE"
|
|||
[ "$AMP" = true ] && ARGS+=" --amp"
|
||||
[ "$RESUME" = true ] && ARGS+=" --resume"
|
||||
[ "$CUDNN_BENCHMARK" = true ] && ARGS+=" --cudnn_benchmark"
|
||||
[ "$PAD_TO_MAX_DURATION" = true ] && ARGS+=" --pad_to_max_duration"
|
||||
# Pass MAX_DURATION / PAD_TO_MAX_DURATION through --override_config.
# Use explicit if-blocks: chaining two `ARGS+=` assignments after `&&` via
# a backslash continuation relies on obscure "assignments without a command
# word" shell behavior and is easy to break when editing.
if [ -n "$MAX_DURATION" ]; then
  ARGS+=" --override_config input_train.audio_dataset.max_duration=$MAX_DURATION"
  ARGS+=" --override_config input_train.filterbank_features.max_duration=$MAX_DURATION"
fi
if [ "$PAD_TO_MAX_DURATION" = true ]; then
  ARGS+=" --override_config input_train.audio_dataset.pad_to_max_duration=True"
  ARGS+=" --override_config input_train.filterbank_features.pad_to_max_duration=True"
fi
|
||||
[ -n "$CHECKPOINT" ] && ARGS+=" --ckpt=$CHECKPOINT"
|
||||
[ -n "$LOG_FILE" ] && ARGS+=" --log_file $LOG_FILE"
|
||||
[ -n "$PRE_ALLOCATE" ] && ARGS+=" --pre_allocate_range $PRE_ALLOCATE"
|
||||
|
|
|
@ -38,6 +38,7 @@ from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
|
|||
process_evaluation_epoch)
|
||||
from common.optimizers import AdamW, lr_policy, Novograd
|
||||
from common.tb_dllogger import flush_log, init_log, log
|
||||
from common.utils import BenchmarkStats
|
||||
from jasper import config
|
||||
from jasper.model import CTCLossNM, GreedyCTCDecoder, Jasper
|
||||
|
||||
|
@ -111,16 +112,17 @@ def parse_args():
|
|||
help='Paths of the training dataset manifest file')
|
||||
io.add_argument('--val_manifests', type=str, required=True, nargs='+',
|
||||
help='Paths of the evaluation datasets manifest files')
|
||||
io.add_argument('--max_duration', type=float,
|
||||
help='Discard samples longer than max_duration')
|
||||
io.add_argument('--pad_to_max_duration', action='store_true', default=False,
|
||||
help='Pad training sequences to max_duration')
|
||||
io.add_argument('--dataset_dir', required=True, type=str,
|
||||
help='Root dir of dataset')
|
||||
io.add_argument('--output_dir', type=str, required=True,
|
||||
help='Directory for logs and checkpoints')
|
||||
io.add_argument('--log_file', type=str, default=None,
|
||||
help='Path to save the training logfile.')
|
||||
io.add_argument('--benchmark_epochs_num', type=int, default=1,
|
||||
help='Number of epochs accounted in final average throughput.')
|
||||
io.add_argument('--override_config', type=str, action='append',
|
||||
help='Overrides a value from a config .yaml.'
|
||||
' Syntax: `--override_config nested.config.key=val`.')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -202,7 +204,7 @@ def main():
|
|||
init_log(args)
|
||||
|
||||
cfg = config.load(args.model_config)
|
||||
config.apply_duration_flags(cfg, args.max_duration, args.pad_to_max_duration)
|
||||
config.apply_config_overrides(cfg, args)
|
||||
|
||||
symbols = helpers.add_ctc_blank(cfg['labels'])
|
||||
|
||||
|
@ -384,11 +386,14 @@ def main():
|
|||
loss.backward()
|
||||
model.zero_grad()
|
||||
|
||||
bmark_stats = BenchmarkStats()
|
||||
|
||||
for epoch in range(start_epoch + 1, args.epochs + 1):
|
||||
if multi_gpu and not use_dali:
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
|
||||
epoch_utts = 0
|
||||
epoch_loss = 0
|
||||
accumulated_batches = 0
|
||||
epoch_start_time = time.time()
|
||||
|
||||
|
@ -434,6 +439,7 @@ def main():
|
|||
accumulated_batches += 1
|
||||
|
||||
if accumulated_batches % args.grad_accumulation_steps == 0:
|
||||
epoch_loss += step_loss
|
||||
optimizer.step()
|
||||
apply_ema(model, ema_model, args.ema)
|
||||
|
||||
|
@ -476,8 +482,11 @@ def main():
|
|||
break
|
||||
|
||||
epoch_time = time.time() - epoch_start_time
|
||||
epoch_loss /= steps_per_epoch
|
||||
log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
|
||||
'took': epoch_time})
|
||||
'took': epoch_time,
|
||||
'loss': epoch_loss})
|
||||
bmark_stats.update(epoch_utts, epoch_time, epoch_loss)
|
||||
|
||||
if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
|
||||
checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
|
||||
|
@ -491,7 +500,7 @@ def main():
|
|||
profiler.stop()
|
||||
torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
|
||||
|
||||
log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})
|
||||
log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))
|
||||
|
||||
if epoch == args.epochs:
|
||||
evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
|
||||
|
|
Loading…
Reference in a new issue