DeepLearningExamples/PyTorch/SpeechRecognition/Jasper/inference.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import math
import pickle
import random
from typing import List

import numpy as np
import toml
import torch
from tqdm import tqdm

import apex
from apex import amp

from dataset import AudioToTextDataLayer
from helpers import (process_evaluation_batch, process_evaluation_epoch,
                     Optimization, add_ctc_labels, AmpOptimizations,
                     print_dict, model_multi_gpu)
from model import AudioPreprocessing, GreedyCTCDecoder, JasperEncoderDecoder


def parse_args():
    parser = argparse.ArgumentParser(description='Jasper')
    parser.add_argument("--local_rank", default=None, type=int)
    parser.add_argument("--batch_size", default=16, type=int, help='data batch size')
    parser.add_argument("--steps", default=None, type=int,
                        help='if not specified, evaluates on the full dataset; otherwise evaluates only this many iterations per worker')
    parser.add_argument("--model_toml", type=str, help='relative model configuration path given dataset folder')
    parser.add_argument("--dataset_dir", type=str, help='absolute path to dataset folder')
    parser.add_argument("--val_manifest", type=str, help='relative path to evaluation dataset manifest file')
    parser.add_argument("--ckpt", default=None, type=str, required=True, help='path to model checkpoint')
    parser.add_argument("--max_duration", default=None, type=float,
                        help='maximum duration of sequences; if None, uses the value from the model configuration file')
    parser.add_argument("--pad_to", default=None, type=int,
                        help='default is the pad_to value from the model configuration; -1 pads to the maximum duration, values > 0 pad each batch to the next multiple of this value')
    parser.add_argument("--fp16", action='store_true', help='use half precision')
    parser.add_argument("--cudnn_benchmark", action='store_true', help='enable cudnn benchmark')
    parser.add_argument("--save_prediction", type=str, default=None, help='if specified, saves predictions in text form at this location')
    parser.add_argument("--logits_save_to", default=None, type=str, help='if specified, saves logits (pickled) to this path')
    parser.add_argument("--seed", default=42, type=int, help='seed')
    return parser.parse_args()
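
# A minimal example invocation on a single GPU; the configuration, dataset,
# manifest and checkpoint paths below are placeholders, not files shipped with
# this script:
#
#   python inference.py \
#       --model_toml=configs/jasper.toml \
#       --dataset_dir=/datasets/LibriSpeech \
#       --val_manifest=/datasets/LibriSpeech/dev-clean-manifest.json \
#       --ckpt=/checkpoints/jasper.pt \
#       --batch_size=16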


def eval(
        data_layer,
        audio_processor,
        encoderdecoder,
        greedy_decoder,
        labels,
        multi_gpu,
        args):
    """performs inference / evaluation

    Args:
        data_layer: data layer object that holds data loader
        audio_processor: data processing module
        encoderdecoder: acoustic model
        greedy_decoder: greedy decoder
        labels: list of labels as output vocabulary
        multi_gpu: true if using multiple gpus
        args: script input arguments
    """
    logits_save_to = args.logits_save_to
    audio_processor.eval()
    encoderdecoder.eval()
    with torch.no_grad():
        _global_var_dict = {
            'predictions': [],
            'transcripts': [],
            'logits': [],
        }
        for it, data in enumerate(tqdm(data_layer.data_iterator)):
            tensors = []
            for d in data:
                tensors.append(d.cuda())
            t_audio_signal_e, t_a_sig_length_e, t_transcript_e, t_transcript_len_e = tensors

            # audio -> acoustic features -> per-frame log probs -> greedy CTC predictions
            inp = (t_audio_signal_e, t_a_sig_length_e)
            t_processed_signal, p_length_e = audio_processor(x=inp)
            t_log_probs_e, _ = encoderdecoder((t_processed_signal, p_length_e))
            t_predictions_e = greedy_decoder(log_probs=t_log_probs_e)

            # accumulate decoded hypotheses, reference transcripts and logits
            values_dict = dict(
                predictions=[t_predictions_e],
                transcript=[t_transcript_e],
                transcript_length=[t_transcript_len_e],
                output=[t_log_probs_e]
            )
            process_evaluation_batch(values_dict, _global_var_dict, labels=labels)

            if args.steps is not None and it + 1 >= args.steps:
                break

        # compute the word error rate over everything accumulated above;
        # only rank 0 reports and saves results when running multi-GPU
        wer, _ = process_evaluation_epoch(_global_var_dict)
        if (not multi_gpu or (multi_gpu and torch.distributed.get_rank() == 0)):
            print("==========>>>>>>Evaluation WER: {0}\n".format(wer))
            if args.save_prediction is not None:
                with open(args.save_prediction, 'w') as fp:
                    fp.write('\n'.join(_global_var_dict['predictions']))
            if logits_save_to is not None:
                logits = []
                for batch in _global_var_dict["logits"]:
                    for i in range(batch.shape[0]):
                        logits.append(batch[i].cpu().numpy())
                with open(logits_save_to, 'wb') as f:
                    pickle.dump(logits, f, protocol=pickle.HIGHEST_PROTOCOL)
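
# The file written via --logits_save_to holds a pickled list with one numpy
# array of per-frame log-probabilities per utterance. A minimal sketch of
# reloading it downstream (e.g. for external rescoring); 'logits.pkl' is a
# placeholder path:
#
#   import pickle
#   with open('logits.pkl', 'rb') as f:
#       logits = pickle.load(f)   # list of numpy arrays, one per utterance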


def main(args):
    # fix all seeds so that evaluation is reproducible
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    print("CUDNN BENCHMARK ", args.cudnn_benchmark)
    assert(torch.cuda.is_available())

    # a --local_rank passed in by the distributed launcher implies multi-GPU evaluation
    if args.local_rank is not None:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    multi_gpu = args.local_rank is not None
    if multi_gpu:
        print("DISTRIBUTED with ", torch.distributed.get_world_size())

    if args.fp16:
        optim_level = Optimization.mxprO3
    else:
        optim_level = Optimization.mxprO0
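    # optim_level is translated to an apex AMP opt_level through the AmpOptimizations
    # mapping used further below (presumably O3, i.e. pure FP16, for mxprO3 and O0,
    # i.e. unmodified FP32, for mxprO0).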

    jasper_model_definition = toml.load(args.model_toml)
    dataset_vocab = jasper_model_definition['labels']['labels']
    ctc_vocab = add_ctc_labels(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = jasper_model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level

    # command-line overrides for the feature-extractor configuration
    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to if args.pad_to >= 0 else "max"

    print('model_config')
    print_dict(jasper_model_definition)
    print('feature_config')
    print_dict(featurizer_config)

    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_dir,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=args.batch_size,
        pad_to_max=featurizer_config['pad_to'] == "max",
        shuffle=False,
        multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    encoderdecoder = JasperEncoderDecoder(jasper_model_definition=jasper_model_definition, feat_in=1024, num_classes=len(ctc_vocab))

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        # the preprocessor's weights live in the checkpoint under an "audio_preprocessor."
        # prefix; strip it so both modules can load from the same state_dict
        # (strict=False lets each module ignore the keys that belong to the other)
        for k in audio_preprocessor.state_dict().keys():
            checkpoint['state_dict'][k] = checkpoint['state_dict'].pop("audio_preprocessor." + k)
        audio_preprocessor.load_state_dict(checkpoint['state_dict'], strict=False)
        encoderdecoder.load_state_dict(checkpoint['state_dict'], strict=False)
    greedy_decoder = GreedyCTCDecoder()
# print("Number of parameters in encoder: {0}".format(model.jasper_encoder.num_weights()))
N = len(data_layer)
step_per_epoch = math.ceil(N / (args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))
if args.steps is not None:
print('-----------------')
print('Have {0} examples to eval on.'.format(args.steps * args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))
print('Have {0} steps / (gpu * epoch).'.format(args.steps))
print('-----------------')
else:
print('-----------------')
print('Have {0} examples to eval on.'.format(N))
print('Have {0} steps / (gpu * epoch).'.format(step_per_epoch))
print('-----------------')
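
    # move the modules to GPU, optionally cast the acoustic model for FP16
    # inference via apex AMP, and wrap it for distributed evaluation when
    # multiple GPUs are used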
    audio_preprocessor.cuda()
    encoderdecoder.cuda()
    if args.fp16:
        encoderdecoder = amp.initialize(
            models=encoderdecoder,
            opt_level=AmpOptimizations[optim_level])
    encoderdecoder = model_multi_gpu(encoderdecoder, multi_gpu)

    eval(
        data_layer=data_layer,
        audio_processor=audio_preprocessor,
        encoderdecoder=encoderdecoder,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        multi_gpu=multi_gpu)


if __name__ == "__main__":
    args = parse_args()
    print_dict(vars(args))
    main(args)
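
# For multi-GPU evaluation the script is expected to be started through the
# PyTorch distributed launcher, which passes --local_rank to each process.
# A sketch with placeholder paths (8 GPUs assumed):
#
#   python -m torch.distributed.launch --nproc_per_node=8 inference.py \
#       --model_toml=configs/jasper.toml \
#       --dataset_dir=/datasets/LibriSpeech \
#       --val_manifest=/datasets/LibriSpeech/dev-clean-manifest.json \
#       --ckpt=/checkpoints/jasper.pt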