336 lines
15 KiB
Python
336 lines
15 KiB
Python
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
import itertools
|
|
from typing import List
|
|
from tqdm import tqdm
|
|
import math
|
|
import toml
|
|
from dataset import AudioToTextDataLayer
|
|
from helpers import process_evaluation_batch, process_evaluation_epoch, add_ctc_labels, print_dict, model_multi_gpu, __ctc_decoder_predictions_tensor
|
|
from model import AudioPreprocessing, GreedyCTCDecoder, JasperEncoderDecoder
|
|
from parts.features import audio_from_file
|
|
import torch
|
|
import torch.nn as nn
|
|
import apex
|
|
from apex import amp
|
|
import random
|
|
import numpy as np
|
|
import pickle
|
|
import time
|
|
import os
|
|
|
|
def parse_args():
    """Build the Jasper inference CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with all inference/evaluation options.
    """
    parser = argparse.ArgumentParser(description='Jasper')

    # Accept "yes/true/t/1" (case-insensitive) as truthy for "bool"-typed args.
    parser.register("type", "bool", lambda x: x.lower() in ("yes", "true", "t", "1"))

    add = parser.add_argument
    add("--local_rank", default=None, type=int)
    add("--batch_size", default=16, type=int, help='data batch size')
    add("--steps", default=None, help='if not specified do evaluation on full dataset. otherwise only evaluates the specified number of iterations for each worker', type=int)
    add("--model_toml", type=str, help='relative model configuration path given dataset folder')
    add("--dataset_dir", type=str, help='absolute path to dataset folder')
    add("--val_manifest", type=str, help='relative path to evaluation dataset manifest file')
    add("--ckpt", default=None, type=str, required=True, help='path to model checkpoint')
    add("--max_duration", default=None, type=float, help='maximum duration of sequences. if None uses attribute from model configuration file')
    add("--pad_to", default=None, type=int, help="default is pad to value as specified in model configurations. if -1 pad to maximum duration. If > 0 pad batch to next multiple of value")
    add("--amp", "--fp16", action='store_true', help='use half precision')
    add("--cudnn_benchmark", action='store_true', help="enable cudnn benchmark")
    add("--save_prediction", type=str, default=None, help="if specified saves predictions in text form at this location")
    add("--logits_save_to", default=None, type=str, help="if specified will save logits to path")
    add("--seed", default=42, type=int, help='seed')
    add("--output_dir", default="results/", type=str, help="Output directory to store exported models. Only used if --export_model is used")
    add("--export_model", action='store_true', help="Exports the audio_featurizer, encoder and decoder using torch.jit to the output_dir")
    add("--wav", type=str, help='absolute path to .wav file (16KHz)')
    add("--cpu", action="store_true", help="Run inference on CPU")
    add("--ema", action="store_true", help="If available, load EMA model weights")

    return parser.parse_args()
|
|
|
|
def calc_wer(data_layer, audio_processor,
             encoderdecoder, greedy_decoder,
             labels, args, device):
    """Evaluate the model over ``data_layer`` and return ``(wer, results)``.

    ``results`` holds the accumulated predictions, reference transcripts
    and logits from every evaluated minibatch.
    """
    # Unwrap a DistributedDataParallel-style wrapper so .infer() is reachable.
    if hasattr(encoderdecoder, 'module'):
        encoderdecoder = encoderdecoder.module

    # Accumulator for evaluation results across all minibatches / workers.
    results = {
        'predictions': [],
        'transcripts': [],
        'logits': [],
    }

    with torch.no_grad():
        for step, batch in enumerate(tqdm(data_layer.data_iterator)):
            audio, audio_len, transcript, transcript_len = (t.to(device) for t in batch)

            processed = audio_processor(audio, audio_len)
            log_probs, _ = encoderdecoder.infer(processed)
            predictions = greedy_decoder(log_probs)

            # The helper aggregates this minibatch (from all workers) into results.
            process_evaluation_batch(
                dict(
                    predictions=[predictions],
                    transcript=[transcript],
                    transcript_length=[transcript_len],
                    output=[log_probs],
                ),
                results,
                labels=labels,
            )

            # Optional early stop after a fixed number of steps per worker.
            if args.steps is not None and step + 1 >= args.steps:
                break

        # Final aggregation (over minibatches) of the word error rate.
        wer, _ = process_evaluation_epoch(results)

    return wer, results
|
|
|
|
|
|
def jit_export(audio, audio_len, audio_processor, encoderdecoder, greedy_decoder, args):
    """Export the three pipeline stages to TorchScript files in ``args.output_dir``.

    Args:
        audio: raw audio tensor used as the tracing example input.
        audio_len: length tensor matching ``audio``.
        audio_processor: featurizer module (scripted).
        encoderdecoder: acoustic model (traced on featurized input).
        greedy_decoder: greedy CTC decoder (scripted).
        args: parsed CLI namespace (uses model_toml, amp, use_conv_mask, output_dir).

    Returns:
        Tuple of (scripted featurizer, traced acoustic model, scripted decoder).
    """
    print("##############")

    # File-name stem encodes the config name and precision mode.
    module_name = "{}_{}".format(os.path.basename(args.model_toml), "fp16" if args.amp else "fp32")

    # Masked convs were disabled for export in main(), so tag the artifacts.
    if args.use_conv_mask:
        module_name = module_name + "_noMaskConv"

    # Export just the featurizer.
    print("exporting featurizer ...")
    traced_module_feat = torch.jit.script(audio_processor)
    traced_module_feat.save(os.path.join(args.output_dir, module_name + "_feat.pt"))

    # Export just the acoustic model (traced, so it takes features only).
    print("exporting acoustic model ...")
    inp_postFeat, _ = audio_processor(audio, audio_len)
    traced_module_acoustic = torch.jit.trace(encoderdecoder, inp_postFeat)
    traced_module_acoustic.save(os.path.join(args.output_dir, module_name + "_acoustic.pt"))

    # Export just the decoder.
    print("exporting decoder ...")
    # Sanity forward pass through the acoustic model with the example input.
    inp_postAcoustic = encoderdecoder(inp_postFeat)
    # BUG FIX: torch.jit.script takes no example inputs — the second positional
    # parameter is the deprecated `optimize` flag, so passing a tensor there
    # was incorrect. Scripting only needs the module itself.
    traced_module_decode = torch.jit.script(greedy_decoder)
    traced_module_decode.save(os.path.join(args.output_dir, module_name + "_decoder.pt"))
    print("JIT export complete")

    return traced_module_feat, traced_module_acoustic, traced_module_decode
|
|
|
|
def run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels, device):
    """Transcribe one utterance, printing inference time and the transcript."""
    features, feature_lens = audio_processor(audio, audio_len)

    on_gpu = device.type != 'cpu'
    if on_gpu:
        # Make the timing window start after all queued GPU work finishes.
        torch.cuda.synchronize()
    start = time.perf_counter()

    # A traced TorchScript module does not support the (features, lengths) pair.
    if isinstance(encoderdecoder, torch.jit.TracedModule):
        log_probs = encoderdecoder(features)
    else:
        log_probs, _ = encoderdecoder.infer((features, feature_lens))

    if on_gpu:
        torch.cuda.synchronize()
    stop = time.perf_counter()

    predictions = greedy_decoder(log_probs=log_probs)
    hypotheses = __ctc_decoder_predictions_tensor(predictions, labels=labels)

    print("INFERENCE TIME\t\t: {} ms".format((stop - start) * 1000.0))
    print("TRANSCRIPT\t\t:", hypotheses[0])
|
|
|
|
|
|
def eval(
        data_layer,
        audio_processor,
        encoderdecoder,
        greedy_decoder,
        labels,
        multi_gpu,
        device,
        args):
    """performs inference / evaluation

    Args:
        data_layer: data layer object that holds data loader
        audio_processor: data processing module
        encoderdecoder: acoustic model
        greedy_decoder: greedy decoder
        labels: list of labels as output vocabulary
        multi_gpu: true if using multiple gpus
        device: torch device to run on
        args: script input arguments
    """
    logits_save_to = args.logits_save_to

    with torch.no_grad():
        # Single-file mode: transcribe the wav (and, if requested, the
        # TorchScript-exported pipeline as well), then return without WER.
        if args.wav:
            audio, audio_len = audio_from_file(args.wav)
            run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels, device)
            if args.export_model:
                jit_feat, jit_acoustic, jit_decode = jit_export(
                    audio, audio_len, audio_processor, encoderdecoder, greedy_decoder, args)
                run_once(jit_feat, jit_acoustic, jit_decode, audio, audio_len, labels, device)
            return

        # Dataset mode: compute WER over the whole (or truncated) eval set.
        wer, results = calc_wer(data_layer, audio_processor, encoderdecoder, greedy_decoder, labels, args, device)

        # Only rank 0 (or the single process) reports and saves artifacts.
        if not multi_gpu or torch.distributed.get_rank() == 0:
            print("==========>>>>>>Evaluation WER: {0}\n".format(wer))

            if args.save_prediction is not None:
                with open(args.save_prediction, 'w') as fp:
                    fp.write('\n'.join(results['predictions']))

            if logits_save_to is not None:
                # Flatten per-batch logits into one list of per-utterance arrays.
                logits = [batch[i].cpu().numpy()
                          for batch in results["logits"]
                          for i in range(batch.shape[0])]
                with open(logits_save_to, 'wb') as f:
                    pickle.dump(logits, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
|
|
def main(args):
    """Build the Jasper pipeline from CLI args and run inference/evaluation.

    Args:
        args: parsed command-line namespace from parse_args().
    """
    # Seed all RNGs for reproducible evaluation.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # torch.distributed launchers pass --local_rank; its presence implies DDP.
    multi_gpu = args.local_rank is not None

    if args.cpu:
        assert(not multi_gpu)  # distributed CPU inference is not supported
        device = torch.device('cpu')
    else:
        assert(torch.cuda.is_available())
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        print("CUDNN BENCHMARK ", args.cudnn_benchmark)

    if multi_gpu:
        # BUG FIX: the process group must be initialized before any
        # torch.distributed query — the original code called get_world_size()
        # first, which raises "Default process group has not been initialized".
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        print("DISTRIBUTED with ", torch.distributed.get_world_size())

    # apex.amp opt level: O3 (pure fp16) when --amp, otherwise O0 (fp32).
    optim_level = 3 if args.amp else 0

    jasper_model_definition = toml.load(args.model_toml)
    dataset_vocab = jasper_model_definition['labels']['labels']
    ctc_vocab = add_ctc_labels(dataset_vocab)  # vocabulary plus the CTC blank

    val_manifest = args.val_manifest
    featurizer_config = jasper_model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level
    featurizer_config["fp16"] = args.amp

    # Masked convolutions are not TorchScript-exportable; disable for export.
    args.use_conv_mask = jasper_model_definition['encoder'].get('convmask', True)
    if args.use_conv_mask and args.export_model:
        print('WARNING: Masked convs currently not supported for TorchScript. Disabling.')
        jasper_model_definition['encoder']['convmask'] = False

    # CLI overrides for the featurizer configuration.
    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to

    # "max" in the config means pad every batch to the maximum duration (-1).
    if featurizer_config['pad_to'] == "max":
        featurizer_config['pad_to'] = -1

    print('=== model_config ===')
    print_dict(jasper_model_definition)
    print()
    print('=== feature_config ===')
    print_dict(featurizer_config)
    print()
    data_layer = None

    # Dataset-evaluation mode (no --wav): build the evaluation data loader.
    if args.wav is None:
        data_layer = AudioToTextDataLayer(
            dataset_dir=args.dataset_dir,
            featurizer_config=featurizer_config,
            manifest_filepath=val_manifest,
            labels=dataset_vocab,
            batch_size=args.batch_size,
            pad_to_max=featurizer_config['pad_to'] == -1,
            shuffle=False,
            multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    encoderdecoder = JasperEncoderDecoder(jasper_model_definition=jasper_model_definition, feat_in=1024, num_classes=len(ctc_vocab))

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)

        if os.path.isdir(args.ckpt):
            exit(0)  # directory checkpoints are not supported by this script
        else:
            checkpoint = torch.load(args.ckpt, map_location="cpu")
            # Prefer EMA weights when requested and present in the checkpoint.
            if args.ema and 'ema_state_dict' in checkpoint:
                print('Loading EMA state dict')
                sd = 'ema_state_dict'
            else:
                sd = 'state_dict'

            # Strip the "audio_preprocessor." prefix so the preprocessor can
            # load its own weights from the combined checkpoint.
            for k in audio_preprocessor.state_dict().keys():
                checkpoint[sd][k] = checkpoint[sd].pop("audio_preprocessor." + k)
            audio_preprocessor.load_state_dict(checkpoint[sd], strict=False)
            encoderdecoder.load_state_dict(checkpoint[sd], strict=False)

    greedy_decoder = GreedyCTCDecoder()

    # Report the evaluation workload (dataset mode only).
    if args.wav is None:
        N = len(data_layer)
        step_per_epoch = math.ceil(N / (args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))

        if args.steps is not None:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(args.steps * args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))
            print('Have {0} steps / (gpu * epoch).'.format(args.steps))
            print('-----------------')
        else:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(N))
            print('Have {0} steps / (gpu * epoch).'.format(step_per_epoch))
            print('-----------------')

    print("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)

    audio_preprocessor.to(device)
    encoderdecoder.to(device)

    if args.amp:
        encoderdecoder = amp.initialize(models=encoderdecoder,
                                        opt_level='O'+str(optim_level))

    encoderdecoder = model_multi_gpu(encoderdecoder, multi_gpu)
    # Inference only: switch every module to eval mode.
    audio_preprocessor.eval()
    encoderdecoder.eval()
    greedy_decoder.eval()

    eval(
        data_layer=data_layer,
        audio_processor=audio_preprocessor,
        encoderdecoder=encoderdecoder,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        device=device,
        multi_gpu=multi_gpu)
|
|
|
|
if __name__ == "__main__":
    # Parse CLI options, echo them for the log, then run inference/evaluation.
    cli_args = parse_args()
    print_dict(vars(cli_args))
    main(cli_args)
|