DeepLearningExamples/FasterTransformer/v3.0/sample/pytorch/run_squad.py
byshiue b2e89e6e80
[FT] FasterTransformer 3.0 Release (#696)
[FT] feat: Add FasterTransformer v3.0

1. Add support for INT8 quantization in the C++ and TensorFlow ops.
2. Provide tools to quantize the model.
3. Fix a bug where the project could not be built with CMake 3.15 or 3.16.
4. Deprecate FasterTransformer v1.
2020-09-23 10:03:37 +08:00

489 lines
19 KiB
Python

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Models for question-answering on SQuAD (Bert) modified from HuggingFace transformers ."""
import argparse
import logging
import os
import random
import timeit
import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import (
BertConfig,
BertTokenizer,
squad_convert_examples_to_features,
)
from utils.modeling_bert import BertForQuestionAnswering
from transformers.data.metrics.squad_metrics import (
compute_predictions_logits,
squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def set_seed(args):
    """Seed every random number generator this script relies on.

    Python's ``random``, NumPy and PyTorch are seeded with ``args.seed``; the
    CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)
def to_list(tensor):
    """Return the contents of ``tensor`` as plain Python values.

    The tensor is detached from the autograd graph and moved to the CPU
    before ``tolist()`` is applied (a 0-dim tensor yields a bare scalar).
    """
    cpu_tensor = tensor.detach().cpu()
    return cpu_tensor.tolist()
def evaluate(args, model, tokenizer, prefix=""):
    """Evaluate ``model`` on the SQuAD evaluation features and return the metrics.

    Args:
        args: Parsed command-line namespace. Reads ``output_dir``,
            ``local_rank``, ``per_gpu_eval_batch_size``, ``n_gpu``, ``device``,
            ``data_type`` and the post-processing options forwarded to
            ``compute_predictions_logits``; sets ``args.eval_batch_size``.
        model: Callable invoked positionally as
            ``model(input_ids, attention_mask, token_type_ids)``. It is wrapped
            in ``torch.nn.DataParallel`` when ``args.n_gpu > 1``.
        tokenizer: Tokenizer forwarded to feature loading and to
            ``compute_predictions_logits``.
        prefix: Suffix used in the log banner and in the output file names.

    Returns:
        The value returned by ``squad_evaluate(examples, predictions)``.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if args.local_rank in [-1, 0]:
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(args.output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    use_fp16 = args.data_type == 'fp16'
    all_results = []
    # Switching to eval mode once is enough; it does not need repeating per batch.
    model.eval()
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            # Inputs are passed positionally (input_ids, attention_mask,
            # token_type_ids) -- presumably so the same call also works for
            # TorchScript-traced models; confirm before switching to kwargs.
            attention_mask = batch[1].half() if use_fp16 else batch[1]
            example_indices = batch[3]
            outputs = model(batch[0], attention_mask, batch[2])

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)

            output = [to_list(head[i]) for head in outputs]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )
            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    # max(..., 1) keeps the per-example figure defined for an empty dataset.
    logger.info(
        " Evaluation done in total %f secs (%f sec per example)",
        eval_time,
        eval_time / max(len(dataset), 1),
    )

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        args.n_best_size,
        args.max_answer_length,
        args.do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        args.verbose_logging,
        args.version_2_with_negative,
        args.null_score_diff_threshold,
        tokenizer,
    )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD features from the on-disk cache, or build and cache them.

    The cache file lives in ``args.data_dir`` (or ``"."`` when unset) and is
    named ``cached_<dev|train>_<last component of model_name_or_path>_<max_seq_length>``.

    Args:
        args: Parsed command-line namespace. Reads ``local_rank``,
            ``data_dir``, ``model_name_or_path``, ``max_seq_length``,
            ``overwrite_cache``, ``predict_file``, ``train_file``,
            ``version_2_with_negative``, ``doc_stride``, ``max_query_length``
            and ``threads``.
        tokenizer: Tokenizer passed to ``squad_convert_examples_to_features``.
        evaluate: Load the dev split when true, the train split otherwise.
        output_examples: When true, also return the examples and features.

    Returns:
        ``dataset`` alone, or ``(dataset, examples, features)`` when
        ``output_examples`` is true.

    Raises:
        ImportError: If no data directory/file is given and
            ``tensorflow_datasets`` cannot be imported.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE: torch.load unpickles the file, so only use cache files from a
        # trusted location.
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError as err:
                # Chain the original failure so the real import problem stays visible.
                raise ImportError(
                    "If not data_dir is specified, tensorflow_datasets needs to be installed."
                ) from err

            if args.version_2_with_negative:
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
def main():
    """Parse the command line, build the requested BERT variant and evaluate it.

    ``--model_type`` selects the execution path: values starting with ``ths``
    trace the model with TorchScript, ``ext`` and ``thsext`` replace the
    encoder with ``CustomEncoder`` from ``utils.encoder``, and any other value
    (including the default ``ori``) runs the model as loaded. With
    ``--data_type fp16`` the model and the custom encoder weights are
    converted to half precision.

    Returns:
        dict: Metrics collected from ``evaluate``. Empty when ``--do_eval`` is
        not given or when ``local_rank`` is neither -1 nor 0.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )

    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")

    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")

    # Explicit defaults: without one, an omitted --model_type left the value as
    # None and the later ``startswith`` call raised AttributeError.
    parser.add_argument("--model_type", type=str, default="ori", help="ori, ths, ext, thsext")
    parser.add_argument("--data_type", type=str, default="fp32", help="fp32, fp16")
    parser.add_argument('--module_path', type=str, default='./',
                        help='path containing the th_fastertransformer dynamic lib')
    parser.add_argument('--ths_path', type=str, default='./lib/libths_fastertransformer.so',
                        help='path of the ths_fastertransformer dynamic lib file')
    args = parser.parse_args()

    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    # Setup CUDA, GPU & distributed training (this script always targets a CUDA device)
    if args.local_rank == -1:
        device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s",
        args.local_rank,
        device,
        args.n_gpu,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    # The config object itself is not used below; the call is kept so the
    # loading/downloading done by from_pretrained still happens here.
    config = BertConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = BertTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    logger.info("Parameters %s", args)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
        checkpoints = [args.model_name_or_path]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        use_fp16 = args.data_type == 'fp16'
        use_ths = args.model_type.startswith('ths')

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths)
            model.to(args.device)

            if use_fp16:
                logger.info("Use fp16")
                model.half()

            if args.model_type in ('ext', 'thsext'):
                scripted = args.model_type == 'thsext'
                if scripted:
                    logger.info("Use custom BERT encoder for TorchScript")
                else:
                    logger.info("Use custom BERT encoder")
                from utils.encoder import EncoderWeights, CustomEncoder

                weights = EncoderWeights(model.config.num_hidden_layers, model.config.hidden_size, model.bert.encoder)
                weights.to_cuda()
                if use_fp16:
                    weights.to_half()

                head_size = model.config.hidden_size // model.config.num_attention_heads
                if scripted:
                    # The TorchScript variant loads the ths library and is scripted before use.
                    enc = CustomEncoder(model.config.num_hidden_layers,
                                        model.config.num_attention_heads,
                                        head_size,
                                        weights,
                                        os.path.abspath(args.ths_path), True)
                    enc = torch.jit.script(enc)
                else:
                    enc = CustomEncoder(model.config.num_hidden_layers,
                                        model.config.num_attention_heads,
                                        head_size,
                                        weights,
                                        os.path.abspath(args.module_path))
                model.replace_encoder(enc)

            if use_ths:
                logger.info("Use TorchScript mode")
                # Dummy inputs of the evaluation shape, used only to trace the model.
                trace_shape = (args.per_gpu_eval_batch_size, args.max_seq_length)
                fake_input_id = torch.ones(trace_shape, dtype=torch.long).to(args.device)
                fake_mask = torch.ones(args.per_gpu_eval_batch_size, args.max_seq_length).to(args.device)
                fake_type_id = fake_input_id.clone().detach()
                if use_fp16:
                    fake_mask = fake_mask.half()
                model.eval()
                model = torch.jit.trace(model, (fake_input_id, fake_mask, fake_type_id))

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            # Suffix metric names with the step so several checkpoints cannot collide.
            suffix = "_{}".format(global_step) if global_step else ""
            results.update({k + suffix: v for k, v in result.items()})

    logger.info("Results: {}".format(results))

    return results
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()