# DeepLearningExamples/PyTorch/SpeechRecognition/Jasper/inference.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
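"""Inference / evaluation entry point for the Jasper speech-recognition model.

Supports three modes of operation, selected via command-line flags:
  * full-dataset evaluation with word error rate (WER) reporting,
  * transcription of a single 16 kHz .wav file (--wav),
  * TorchScript export of the featurizer, acoustic model and decoder
    (--export_model).
"""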
import argparse
import math
import os
import pickle
import random
import time

import numpy as np
import toml
import torch
from apex import amp
from tqdm import tqdm

from dataset import AudioToTextDataLayer
from helpers import (__ctc_decoder_predictions_tensor, add_ctc_labels,
                     model_multi_gpu, print_dict, process_evaluation_batch,
                     process_evaluation_epoch)
from model import AudioPreprocessing, GreedyCTCDecoder, JasperEncoderDecoder
from parts.features import audio_from_file


def parse_args():
    parser = argparse.ArgumentParser(description='Jasper')
    parser.register("type", "bool", lambda x: x.lower() in ("yes", "true", "t", "1"))
    parser.add_argument("--local_rank", default=None, type=int)
    parser.add_argument("--batch_size", default=16, type=int, help='data batch size')
    parser.add_argument("--steps", default=None, type=int,
                        help='if not specified, evaluate the full dataset; otherwise evaluate only this many iterations per worker')
    parser.add_argument("--model_toml", type=str, help='relative path to the model configuration file, given the dataset folder')
    parser.add_argument("--dataset_dir", type=str, help='absolute path to dataset folder')
    parser.add_argument("--val_manifest", type=str, help='relative path to evaluation dataset manifest file')
    parser.add_argument("--ckpt", default=None, type=str, required=True, help='path to model checkpoint')
    parser.add_argument("--max_duration", default=None, type=float,
                        help='maximum duration of sequences; if None, use the value from the model configuration file')
    parser.add_argument("--pad_to", default=None, type=int,
                        help="by default, pad to the value specified in the model configuration; -1 pads to the maximum duration, a value > 0 pads each batch to the next multiple of that value")
    parser.add_argument("--amp", "--fp16", action='store_true', help='use half precision')
    parser.add_argument("--cudnn_benchmark", action='store_true', help="enable cudnn benchmark")
    parser.add_argument("--save_prediction", type=str, default=None, help="if specified, save predictions in text form at this location")
    parser.add_argument("--logits_save_to", default=None, type=str, help="if specified, save logits to this path")
    parser.add_argument("--seed", default=42, type=int, help='seed')
    parser.add_argument("--output_dir", default="results/", type=str,
                        help="output directory for exported models; only used with --export_model")
    parser.add_argument("--export_model", action='store_true',
                        help="export the audio featurizer, encoder and decoder with torch.jit to output_dir")
    parser.add_argument("--wav", type=str, help='absolute path to a 16 kHz .wav file')
    parser.add_argument("--cpu", action="store_true", help="run inference on CPU")
    parser.add_argument("--ema", action="store_true", help="if available, load EMA model weights")
    # FIXME Unused, but passed by Triton helper scripts
    parser.add_argument("--pyt_fp16", action='store_true', help='use half precision')
    return parser.parse_args()
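
# Example invocation for dataset evaluation (all paths and file names are
# illustrative, not taken from this file):
#   python inference.py --ckpt /checkpoints/jasper_fp16.pt \
#       --model_toml configs/jasper10x5dr.toml \
#       --dataset_dir /datasets/LibriSpeech --val_manifest dev-clean-wav.json \
#       --batch_size 16 --amp
# For single-file transcription, pass --wav /path/to/audio.wav instead of a manifest.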


def calc_wer(data_layer, audio_processor,
             encoderdecoder, greedy_decoder,
             labels, args, device):
    encoderdecoder = encoderdecoder.module if hasattr(encoderdecoder, 'module') else encoderdecoder
    with torch.no_grad():
        # reset global_var_dict - results of evaluation will be stored there
        _global_var_dict = {
            'predictions': [],
            'transcripts': [],
            'logits': [],
        }

        # evaluation mini-batch loop
        for it, data in enumerate(tqdm(data_layer.data_iterator)):
            tensors = [t.to(device) for t in data]
            t_audio_signal_e, t_a_sig_length_e, t_transcript_e, t_transcript_len_e = tensors
            t_processed_signal = audio_processor(t_audio_signal_e, t_a_sig_length_e)
            t_log_probs_e, _ = encoderdecoder.infer(t_processed_signal)
            t_predictions_e = greedy_decoder(t_log_probs_e)
            values_dict = dict(
                predictions=[t_predictions_e],
                transcript=[t_transcript_e],
                transcript_length=[t_transcript_len_e],
                output=[t_log_probs_e]
            )
            # values_dict will contain results from all workers
            process_evaluation_batch(values_dict, _global_var_dict, labels=labels)
            if args.steps is not None and it + 1 >= args.steps:
                break

        # final aggregation (over minibatches) and logging of results
        wer, _ = process_evaluation_epoch(_global_var_dict)
        return wer, _global_var_dict
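

# jit_export() serializes the pipeline in three pieces: the featurizer and the
# greedy decoder are compiled with torch.jit.script, while the acoustic model
# is traced with torch.jit.trace on a sample feature tensor.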
def jit_export(audio, audio_len, audio_processor, encoderdecoder, greedy_decoder, args):
    print("##############")
    module_name = "{}_{}".format(os.path.basename(args.model_toml), "fp16" if args.amp else "fp32")
    if args.use_conv_mask:
        # masked convolutions were disabled for export in main(); mark it in the file name
        module_name = module_name + "_noMaskConv"

    # Export just the featurizer
    print("exporting featurizer ...")
    traced_module_feat = torch.jit.script(audio_processor)
    traced_module_feat.save(os.path.join(args.output_dir, module_name + "_feat.pt"))

    # Export just the acoustic model
    print("exporting acoustic model ...")
    inp_postFeat, _ = audio_processor(audio, audio_len)
    traced_module_acoustic = torch.jit.trace(encoderdecoder, inp_postFeat)
    traced_module_acoustic.save(os.path.join(args.output_dir, module_name + "_acoustic.pt"))

    # Export just the decoder (torch.jit.script takes no example inputs)
    print("exporting decoder ...")
    traced_module_decode = torch.jit.script(greedy_decoder)
    traced_module_decode.save(os.path.join(args.output_dir, module_name + "_decoder.pt"))

    print("JIT export complete")
    return traced_module_feat, traced_module_acoustic, traced_module_decode
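

# run_once() times a single forward pass through the acoustic model. The
# torch.cuda.synchronize() calls around the timed region make the wall-clock
# measurement accurate on GPU, where kernel launches are asynchronous.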
def run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels, device):
    features, lens = audio_processor(audio, audio_len)
    if device.type != 'cpu':
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    # TorchScripted model does not support (features, lengths)
    if isinstance(encoderdecoder, torch.jit.TracedModule):
        t_log_probs_e = encoderdecoder(features)
    else:
        t_log_probs_e, _ = encoderdecoder.infer((features, lens))
    if device.type != 'cpu':
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    t_predictions_e = greedy_decoder(log_probs=t_log_probs_e)
    hypotheses = __ctc_decoder_predictions_tensor(t_predictions_e, labels=labels)
    print("INFERENCE TIME\t\t: {} ms".format((t1 - t0) * 1000.0))
    print("TRANSCRIPT\t\t:", hypotheses[0])
def eval(
        data_layer,
        audio_processor,
        encoderdecoder,
        greedy_decoder,
        labels,
        multi_gpu,
        device,
        args):
    """Performs inference / evaluation.

    Args:
        data_layer: data layer object that holds the data loader
        audio_processor: data processing module
        encoderdecoder: acoustic model
        greedy_decoder: greedy decoder
        labels: list of labels serving as the output vocabulary
        multi_gpu: true if using multiple GPUs
        device: torch device to run on
        args: script input arguments
    """
    logits_save_to = args.logits_save_to
    with torch.no_grad():
        if args.wav:
            audio, audio_len = audio_from_file(args.wav)
            run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels, device)
            if args.export_model:
                jit_audio_processor, jit_encoderdecoder, jit_greedy_decoder = jit_export(
                    audio, audio_len, audio_processor, encoderdecoder, greedy_decoder, args)
                run_once(jit_audio_processor, jit_encoderdecoder, jit_greedy_decoder, audio, audio_len, labels, device)
            return

        wer, _global_var_dict = calc_wer(data_layer, audio_processor, encoderdecoder, greedy_decoder, labels, args, device)
        if not multi_gpu or torch.distributed.get_rank() == 0:
            print("==========>>>>>>Evaluation WER: {0}\n".format(wer))
            if args.save_prediction is not None:
                with open(args.save_prediction, 'w') as fp:
                    fp.write('\n'.join(_global_var_dict['predictions']))
            if logits_save_to is not None:
                logits = []
                for batch in _global_var_dict["logits"]:
                    for i in range(batch.shape[0]):
                        logits.append(batch[i].cpu().numpy())
                with open(logits_save_to, 'wb') as f:
                    pickle.dump(logits, f, protocol=pickle.HIGHEST_PROTOCOL)

    # if args.export_model:
    #     feat, acoustic, decoder = jit_export(inp, audio_processor, encoderdecoder, greedy_decoder, args)
    #     wer_after = calc_wer(data_layer, feat, acoustic, decoder, labels, args)
    #     print("===>>>Before WER: {0}".format(wer))
    #     print("===>>>Traced WER: {0}".format(wer_after))
    #     print("===>>>Diff : {0} %".format((wer_after - wer_before) * 100.0 / wer_before))
    #     print("")
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    multi_gpu = args.local_rank is not None
    if args.cpu:
        assert not multi_gpu
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        print("CUDNN BENCHMARK ", args.cudnn_benchmark)

    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        # the process group must be initialized before the world size can be queried
        print("DISTRIBUTED with ", torch.distributed.get_world_size())
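
    # Apex AMP optimization level: O3 (pure FP16) when --amp is given,
    # otherwise O0 (pure FP32); the digit is also recorded in the
    # featurizer config below.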
    optim_level = 3 if args.amp else 0

    jasper_model_definition = toml.load(args.model_toml)
    dataset_vocab = jasper_model_definition['labels']['labels']
    ctc_vocab = add_ctc_labels(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = jasper_model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level
    featurizer_config["fp16"] = args.amp

    args.use_conv_mask = jasper_model_definition['encoder'].get('convmask', True)
    if args.use_conv_mask and args.export_model:
        print('WARNING: Masked convs currently not supported for TorchScript. Disabling.')
        jasper_model_definition['encoder']['convmask'] = False

    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to
    if featurizer_config['pad_to'] == "max":
        featurizer_config['pad_to'] = -1

    print('=== model_config ===')
    print_dict(jasper_model_definition)
    print()
    print('=== feature_config ===')
    print_dict(featurizer_config)
    print()
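
    # The data layer is only needed for dataset evaluation; single-file
    # transcription via --wav bypasses it entirely.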
    data_layer = None
    if args.wav is None:
        data_layer = AudioToTextDataLayer(
            dataset_dir=args.dataset_dir,
            featurizer_config=featurizer_config,
            manifest_filepath=val_manifest,
            labels=dataset_vocab,
            batch_size=args.batch_size,
            pad_to_max=featurizer_config['pad_to'] == -1,
            shuffle=False,
            multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    encoderdecoder = JasperEncoderDecoder(jasper_model_definition=jasper_model_definition,
                                          feat_in=1024, num_classes=len(ctc_vocab))
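
    # The checkpoint stores one flat state dict for the whole model; the
    # "audio_preprocessor." keys are moved out so that the preprocessor and
    # the encoder/decoder can each load their own parameters.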
    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        if os.path.isdir(args.ckpt):
            # checkpoint directories are not supported; expect a single file
            exit(0)
        else:
            checkpoint = torch.load(args.ckpt, map_location="cpu")
            if args.ema and 'ema_state_dict' in checkpoint:
                print('Loading EMA state dict')
                sd = 'ema_state_dict'
            else:
                sd = 'state_dict'
            for k in audio_preprocessor.state_dict().keys():
                checkpoint[sd][k] = checkpoint[sd].pop("audio_preprocessor." + k)
            audio_preprocessor.load_state_dict(checkpoint[sd], strict=False)
            encoderdecoder.load_state_dict(checkpoint[sd], strict=False)
    greedy_decoder = GreedyCTCDecoder()

    # print("Number of parameters in encoder: {0}".format(model.jasper_encoder.num_weights()))
    if args.wav is None:
        N = len(data_layer)
        world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
        step_per_epoch = math.ceil(N / (args.batch_size * world_size))

        if args.steps is not None:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(args.steps * args.batch_size * world_size))
            print('Have {0} steps / (gpu * epoch).'.format(args.steps))
            print('-----------------')
        else:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(N))
            print('Have {0} steps / (gpu * epoch).'.format(step_per_epoch))
            print('-----------------')

    print("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)
    audio_preprocessor.to(device)
    encoderdecoder.to(device)
    if args.amp:
        encoderdecoder = amp.initialize(models=encoderdecoder,
                                        opt_level='O' + str(optim_level))
    encoderdecoder = model_multi_gpu(encoderdecoder, multi_gpu)

    audio_preprocessor.eval()
    encoderdecoder.eval()
    greedy_decoder.eval()

    eval(
        data_layer=data_layer,
        audio_processor=audio_preprocessor,
        encoderdecoder=encoderdecoder,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        device=device,
        multi_gpu=multi_gpu)


if __name__ == "__main__":
    args = parse_args()
    print_dict(vars(args))
    main(args)