# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Models for question answering on SQuAD (BERT), modified from HuggingFace transformers."""

import argparse
import logging
import os
import random
import timeit

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

from transformers import (
    BertConfig,
    BertTokenizer,
    squad_convert_examples_to_features,
)
from utils.modeling_bert import BertForQuestionAnswering
from transformers.data.metrics.squad_metrics import (
    compute_predictions_logits,
    squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor


logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def to_list(tensor):
    return tensor.detach().cpu().tolist()


def evaluate(args, model, tokenizer, prefix=""):
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            # Positional arguments: (input_ids, attention_mask, token_type_ids).
            # The attention mask is cast to half precision when running the fp16 model.
            inputs = [batch[0], batch[1].half() if args.data_type == 'fp16' else batch[1], batch[2]]

            example_indices = batch[3]

            outputs = model(*inputs)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)

            output = [to_list(output[i]) for output in outputs]

            # Some models (XLNet, XLM) return five tensors for their predictions,
            # while the other "simpler" models return only two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )
            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    logger.info("  Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / len(dataset))

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        args.n_best_size,
        args.max_answer_length,
        args.do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        args.verbose_logging,
        args.version_2_with_negative,
        args.null_score_diff_threshold,
        tokenizer,
    )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

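
# Note on the return value of evaluate(): squad_evaluate() from
# transformers.data.metrics.squad_metrics returns a dict of aggregate metrics;
# for SQuAD v1.1 this includes "exact" and "f1" (both percentages) along with
# counts such as "total" (exact key set may vary across transformers versions).

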
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training processes the dataset;
        # the others will use the cache.
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError("If no data_dir is specified, tensorflow_datasets needs to be installed.")

            if args.version_2_with_negative:
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Release the other processes now that the first process has built the cache.
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset

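
# For reference, the cache file written above follows the pattern
# cached_{dev|train}_{model-name}_{max_seq_length}; e.g. evaluating with
# --model_name_or_path bert-base-uncased --max_seq_length 384 produces
# "cached_dev_bert-base-uncased_384" in the data directory.

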
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task. "
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there. "
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there. "
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold, predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )

    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting examples to features")
    parser.add_argument("--model_type", type=str,
                        help="Model variant to run: 'ori' (plain HuggingFace model), 'ths' (TorchScript), "
                             "'ext' (custom FasterTransformer encoder), or 'thsext' (TorchScript + custom encoder).")
    parser.add_argument("--data_type", type=str, help="Inference precision: 'fp32' or 'fp16'.")
    parser.add_argument('--module_path', type=str, default='./',
                        help='path containing the th_fastertransformer dynamic lib')
    parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so',
                        help='path of the ths_fastertransformer dynamic lib file')
    args = parser.parse_args()

    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be larger than the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s",
        args.local_rank,
        device,
        args.n_gpu,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    config = BertConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = BertTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    logger.info("Parameters %s", args)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
        checkpoints = [args.model_name_or_path]

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            use_ths = args.model_type.startswith('ths')
            model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths)  # , force_download=True)
            model.to(args.device)

            if args.data_type == 'fp16':
                logger.info("Use fp16")
                model.half()
            if args.model_type == 'ext':
                logger.info("Use custom BERT encoder")
                from utils.encoder import EncoderWeights, CustomEncoder
                weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder)
                weights.to_cuda()
                if args.data_type == 'fp16':
                    weights.to_half()
                enc = CustomEncoder(model.config.num_hidden_layers,
                                    model.config.num_attention_heads,
                                    model.config.hidden_size // model.config.num_attention_heads,
                                    weights,
                                    os.path.abspath(args.module_path))
                model.replace_encoder(enc)
            if args.model_type == 'thsext':
                logger.info("Use custom BERT encoder for TorchScript")
                from utils.encoder import EncoderWeights, CustomEncoder
                weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder)
                weights.to_cuda()
                if args.data_type == 'fp16':
                    weights.to_half()
                enc = CustomEncoder(model.config.num_hidden_layers,
                                    model.config.num_attention_heads,
                                    model.config.hidden_size // model.config.num_attention_heads,
                                    weights,
                                    os.path.abspath(args.ths_path), True)
                enc_ = torch.jit.script(enc)
                model.replace_encoder(enc_)
            if use_ths:
                logger.info("Use TorchScript mode")
                # Trace the model once with dummy inputs of the evaluation shape,
                # so subsequent calls run the compiled graph.
                fake_input_id = torch.LongTensor(args.per_gpu_eval_batch_size, args.max_seq_length)
                fake_input_id.fill_(1)
                fake_input_id = fake_input_id.to(args.device)
                fake_mask = torch.ones(args.per_gpu_eval_batch_size, args.max_seq_length).to(args.device)
                fake_type_id = fake_input_id.clone().detach()
                if args.data_type == 'fp16':
                    fake_mask = fake_mask.half()
                model.eval()
                model_ = torch.jit.trace(model, (fake_input_id, fake_mask, fake_type_id))
                model = model_

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))

    return results


if __name__ == "__main__":
    main()
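
# Example invocation (a sketch only; the script name, data files, and library
# paths below are placeholders to adapt to your setup):
#
#   python run_squad.py \
#       --model_name_or_path bert-base-uncased \
#       --do_eval \
#       --do_lower_case \
#       --data_dir ./squad_data \
#       --predict_file dev-v1.1.json \
#       --output_dir ./squad_output \
#       --model_type thsext \
#       --data_type fp16 \
#       --ths_path ./lib/libths_fastertransformer.so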