DeepLearningExamples/TensorFlow/LanguageModeling/BERT/notebooks/bert_squad_tf_finetuning.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# =============================================================================="
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
"\n",
"# BERT Question Answering Fine-Tuning with Mixed Precision"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Overview\n",
"\n",
"Bidirectional Embedding Representations from Transformers (BERT), is a method of pre-training language representations which obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks. \n",
"\n",
"The original paper can be found here: https://arxiv.org/abs/1810.04805.\n",
"\n",
"NVIDIA's BERT 19.10 is an optimized version of Google's official implementation, leveraging mixed precision arithmetic and tensor cores on V100 GPUS for faster training times while maintaining target accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.a Learning objectives\n",
"\n",
"This notebook demonstrates:\n",
"- Fine-Tuning on Question Answering (QA) task with BERT Large model\n",
"- The use/download of pretrained NVIDIA BERT models\n",
"- Use of Mixed Precision for Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Requirements\n",
"\n",
"Please refer to Section 2. of the ReadMe file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. BERT Question Answering Task\n",
"\n",
"Here we run QA fine-tuning on a pre-trained BERT model.\n",
"To fine-tune we will use the [SQuaD 1.1 Dataset](https://rajpurkar.github.io/SQuAD-explorer/) which contains 100,000+ question-answer pairs on 500+ articles."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"data_dir = '../data/download'\n",
"\n",
"# SQuAD json for training\n",
"train_file = os.path.join(data_dir, 'squad/v1.1/train-v1.1.json')\n",
"# json for inference\n",
"predict_file = os.path.join(data_dir, 'squad/v1.1/dev-v1.1.json')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.a Mixed Precision\n",
"\n",
"Mixed precision training offers significant computational speedup by performing operations in half-precision format, while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of tensor cores in the Volta and Turing architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures.\n",
"\n",
"For information about:\n",
"- How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.\n",
"- How to access and enable AMP for TensorFlow, see [Using TF-AMP](https://docs.nvidia.com/deeplearning/dgx/tensorflow-user-guide/index.html#tfamp) from the TensorFlow User Guide.\n",
"- Techniques used for mixed precision training, see the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we control mixed precision execution with the following flag: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"use_fp16 = True;\n",
"\n",
"import os\n",
"os.environ[\"TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE\"] = \"1\" if use_fp16 else \"0\" \n",
"\n",
"# For detailed debug uncomment the following line:\n",
"#os.environ[\"TF_CPP_VMODULE\"]=\"auto_mixed_precision=2\""
]
},
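{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mixed precision gets its speedup from tensor cores, which require a GPU with compute capability 7.0 or higher (Volta/Turing). As a minimal sanity check (a sketch using TF 1.x test utilities, not part of the original flow), we can verify that such a GPU is visible:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Sanity check (sketch): a CUDA GPU must be visible, and tensor cores\n",
"# additionally require compute capability >= 7.0 (Volta/Turing).\n",
"print(\"GPU available:\", tf.test.is_gpu_available(cuda_only=True))\n",
"print(\"Tensor-core capable GPU:\",\n",
"      tf.test.is_gpu_available(cuda_only=True, min_cuda_compute_capability=(7, 0)))"
]
},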
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Pre-Trained NVIDIA BERT TF Models\n",
"\n",
"Based on the model size, we have the following two default configurations of BERT.\n",
"\n",
"| **Model** | **Hidden layers** | **Hidden unit size** | **Attention heads** | **Feedforward filter size** | **Max sequence length** | **Parameters** |\n",
"|:---------:|:----------:|:----:|:---:|:--------:|:---:|:----:|\n",
"|BERTBASE |12 encoder| 768| 12|4 x 768|512|110M|\n",
"|BERTLARGE|24 encoder|1024| 16|4 x 1024|512|330M|\n",
"\n",
"We will use large pre-trained models avaialble on NGC (NVIDIA GPU Cluster, https://ngc.nvidia.com).\n",
"There are many configuration available, in particular we will download and use the following:\n",
"\n",
"**bert_tf_large_fp16_384**\n",
"\n",
"Which is pre-trained using the Wikipedia and Book corpus datasets as training data. \n",
"We will fine-tune on the SQuaD 1.1 Dataset."
]
},
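{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sanity check on the parameter counts in the table above, we can estimate them from the layer count and hidden size alone. The sketch below ignores biases, LayerNorm and position/segment embeddings, so the totals are approximate:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Back-of-envelope parameter count for the configurations in the table above.\n",
"# Ignores biases, LayerNorm and position/segment embeddings (approximate).\n",
"def approx_bert_params(num_layers, hidden_size, vocab_size=30522):\n",
"    embeddings = vocab_size * hidden_size              # WordPiece embedding table\n",
"    attention = 4 * hidden_size * hidden_size          # Q, K, V and output projections\n",
"    feedforward = 2 * hidden_size * (4 * hidden_size)  # two dense layers, 4x filter size\n",
"    return embeddings + num_layers * (attention + feedforward)\n",
"\n",
"print(\"BERT-Base:  ~%dM parameters\" % (approx_bert_params(12, 768) // 10**6))\n",
"print(\"BERT-Large: ~%dM parameters\" % (approx_bert_params(24, 1024) // 10**6))"
]
},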
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create the folders for the pre-trained models:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# bert_tf_large_fp16_384\n",
"DATA_DIR_FP16 = data_dir + '/pretrained_model_fp16'\n",
"!mkdir -p $DATA_DIR_FP16\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP16/bert_for_tensorflow.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_for_tensorflow/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/bert_for_tensorflow.zip "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the code that follows we will refer to this model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"notebooks_dir = '../notebooks'\n",
"\n",
"working_dir = '..'\n",
"if working_dir not in sys.path:\n",
" sys.path.append(working_dir)\n",
"\n",
"init_checkpoint = os.path.join(data_dir, 'pretrained_model_fp16/model.ckpt-1000000')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Running QA task fine-tuning\n",
"\n",
"In order to run Q-A inference we will follow step-by-step a simplified flow implemented in run_squad.py:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import run_squad\n",
"\n",
"import json\n",
"import tensorflow as tf\n",
"import modeling\n",
"import tokenization\n",
"import time\n",
"import random\n",
"\n",
"import optimization\n",
"\n",
"tf.logging.set_verbosity(tf.logging.INFO)\n",
"\n",
"# Create the output directory where all the results are saved.\n",
"output_dir = os.path.join(working_dir, 'results')\n",
"tf.gfile.MakeDirs(output_dir)\n",
"\n",
"# The config json file corresponding to the pre-trained BERT model.\n",
"# This specifies the model architecture.\n",
"bert_config_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/bert_config.json')\n",
"\n",
"# The vocabulary file that the BERT model was trained on.\n",
"vocab_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/vocab.txt')\n",
"\n",
"# Whether to lower case the input text. \n",
"# Should be True for uncased models and False for cased models.\n",
"do_lower_case = True\n",
" \n",
"# Total batch size for predictions\n",
"predict_batch_size = 1\n",
"params = dict([('batch_size', predict_batch_size)])\n",
"\n",
"# The maximum total input sequence length after WordPiece tokenization. \n",
"# Sequences longer than this will be truncated, and sequences shorter than this will be padded.\n",
"max_seq_length = 384\n",
"\n",
"# When splitting up a long document into chunks, how much stride to take between chunks.\n",
"doc_stride = 128\n",
"\n",
"# The maximum number of tokens for the question. \n",
"# Questions longer than this will be truncated to this length.\n",
"max_query_length = 64\n",
"\n",
"# This is a WA to use flags from here:\n",
"flags = tf.flags\n",
"\n",
"if 'f' not in tf.flags.FLAGS: \n",
" tf.app.flags.DEFINE_string('f', '', 'kernel')\n",
"FLAGS = flags.FLAGS\n",
"\n",
"verbose_logging = True\n",
"# Set to True if the dataset has samples with no answers. For SQuAD 1.1, this is set to False\n",
"version_2_with_negative = False\n",
"\n",
"# The total number of n-best predictions to generate in the nbest_predictions.json output file.\n",
"n_best_size = 20\n",
"\n",
"# The maximum length of an answer that can be generated. \n",
"# This is needed because the start and end predictions are not conditioned on one another.\n",
"max_answer_length = 30\n",
"\n",
"# The initial learning rate for Adam\n",
"learning_rate = 5e-6\n",
"\n",
"# Total batch size for training\n",
"train_batch_size = 3\n",
"\n",
"# Proportion of training to perform linear learning rate warmup for\n",
"warmup_proportion = 0.1\n",
"\n",
"# # Total number of training epochs to perform (results will improve if trained with epochs)\n",
"num_train_epochs = 2\n",
"\n",
"global_batch_size = train_batch_size\n",
"training_hooks = []\n",
"training_hooks.append(run_squad.LogTrainRunHook(global_batch_size, 0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create the tokenizer and the training tf_record:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Validate the casing config consistency with the checkpoint name.\n",
"tokenization.validate_case_matches_checkpoint(do_lower_case, init_checkpoint)\n",
"\n",
"# Create the tokenizer.\n",
"tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)\n",
" \n",
"# Load the configuration from file\n",
"bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
"\n",
"config = tf.ConfigProto(log_device_placement=True) \n",
"\n",
"run_config = tf.estimator.RunConfig(\n",
" model_dir=output_dir,\n",
" session_config=config,\n",
" save_checkpoints_steps=1000,\n",
" keep_checkpoint_max=1)\n",
"\n",
"# Read the training examples from the training file:\n",
"train_examples = run_squad.read_squad_examples(input_file=train_file, is_training=True)\n",
"\n",
"num_train_steps = int(len(train_examples) / global_batch_size * num_train_epochs)\n",
"num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
"\n",
"# Pre-shuffle the input to avoid having to make a very large shuffle\n",
"# buffer in in the `input_fn`.\n",
"rng = random.Random(12345)\n",
"rng.shuffle(train_examples)\n",
"\n",
"start_index = 0 \n",
"end_index = len(train_examples)\n",
"tmp_filenames = os.path.join(output_dir, \"train.tf_record\")\n",
"\n",
"# We write to a temporary file to avoid storing very large constant tensors\n",
"# in memory.\n",
"train_writer = run_squad.FeatureWriter(\n",
" filename=tmp_filenames,\n",
" is_training=True)\n",
"\n",
"run_squad.convert_examples_to_features(\n",
" examples=train_examples[start_index:end_index],\n",
" tokenizer=tokenizer,\n",
" max_seq_length=max_seq_length,\n",
" doc_stride=doc_stride,\n",
" max_query_length=max_query_length,\n",
" is_training=True,\n",
" output_fn=train_writer.process_feature)\n",
"\n",
"train_writer.close()\n",
"\n",
"tf.logging.info(\"***** Running training *****\")\n",
"tf.logging.info(\" Num orig examples = %d\", end_index - start_index)\n",
"tf.logging.info(\" Num split examples = %d\", train_writer.num_features)\n",
"tf.logging.info(\" Batch size = %d\", train_batch_size)\n",
"tf.logging.info(\" Num steps = %d\", num_train_steps)\n",
"tf.logging.info(\" LR = %f\", learning_rate)\n",
"\n",
"del train_examples"
]
},
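{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can run the WordPiece tokenizer we just created on a sample question (the sentence below is illustrative only; any string works):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: tokenize a sample question with the WordPiece tokenizer\n",
"# created above, then map the tokens to vocabulary ids.\n",
"sample_question = \"Where do water droplets collide with ice crystals to form precipitation?\"\n",
"sample_tokens = tokenizer.tokenize(sample_question)\n",
"print(sample_tokens)\n",
"print(tokenizer.convert_tokens_to_ids(sample_tokens))"
]
},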
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to create the model for the estimator:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model_fn(features, labels, mode, params): # pylint: disable=unused-argument\n",
" unique_ids = features[\"unique_ids\"]\n",
" input_ids = features[\"input_ids\"]\n",
" input_mask = features[\"input_mask\"]\n",
" segment_ids = features[\"segment_ids\"]\n",
" \n",
" is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
"\n",
" (start_logits, end_logits) = run_squad.create_model(\n",
" bert_config=bert_config,\n",
" is_training=is_training,\n",
" input_ids=input_ids,\n",
" input_mask=input_mask,\n",
" segment_ids=segment_ids,\n",
" use_one_hot_embeddings=False)\n",
"\n",
" tvars = tf.trainable_variables()\n",
"\n",
" initialized_variable_names = {}\n",
" if init_checkpoint:\n",
" (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)\n",
" tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
"\n",
" output_spec = None\n",
" if mode == tf.estimator.ModeKeys.TRAIN:\n",
" seq_length = modeling.get_shape_list(input_ids)[1]\n",
" \n",
" def compute_loss(logits, positions):\n",
" one_hot_positions = tf.one_hot(positions, depth=seq_length, dtype=tf.float32)\n",
" log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
" loss = -tf.reduce_mean(tf.reduce_sum(one_hot_positions * log_probs, axis=-1))\n",
" return loss\n",
"\n",
" start_positions = features[\"start_positions\"]\n",
" end_positions = features[\"end_positions\"]\n",
" start_loss = compute_loss(start_logits, start_positions)\n",
" end_loss = compute_loss(end_logits, end_positions)\n",
" total_loss = (start_loss + end_loss) / 2.0\n",
" \n",
" train_op = optimization.create_optimizer(total_loss, learning_rate, num_train_steps, num_warmup_steps, None, False, use_fp16)\n",
" \n",
" output_spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op)\n",
" \n",
" elif mode == tf.estimator.ModeKeys.PREDICT:\n",
" predictions = {\n",
" \"unique_ids\": unique_ids,\n",
" \"start_logits\": start_logits,\n",
" \"end_logits\": end_logits,\n",
" }\n",
" output_spec = tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\n",
"\n",
" return output_spec\n",
"\n",
"estimator = tf.estimator.Estimator(\n",
" model_fn=model_fn,\n",
" config=run_config,\n",
" params=params)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.a Fine Tuning\n",
"\n",
"Fine tuning is performed using the run_squad.py.\n",
"\n",
"The run_squad.sh script trains a model and performs evaluation on the SQuaD v1.1 dataset. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"train_input_fn = run_squad.input_fn_builder(\n",
" input_file=tmp_filenames,\n",
" batch_size=train_batch_size,\n",
" seq_length=max_seq_length,\n",
" is_training=True,\n",
" drop_remainder=True,\n",
" hvd=None)\n",
"\n",
"train_start_time = time.time()\n",
"estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)\n",
"train_time_elapsed = time.time() - train_start_time\n",
"train_time_wo_startup = training_hooks[-1].total_time\n",
"\n",
"avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_wo_startup if train_time_wo_startup else 0\n",
"\n",
"tf.logging.info(\"-----------------------------\")\n",
"tf.logging.info(\"Total Training Time = %0.2f Training Time W/O start up overhead = %0.2f \"\n",
" \"Sentences processed = %d\", train_time_elapsed, train_time_wo_startup,\n",
" num_train_steps * global_batch_size)\n",
"tf.logging.info(\"Training Performance = %0.4f sentences/sec\", avg_sentences_per_second)\n",
"tf.logging.info(\"-----------------------------\")"
]
},
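{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Estimator saved the fine-tuned weights under `output_dir`. The latest checkpoint there is what `estimator.predict` will load by default in the next step; we can confirm it exists:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# estimator.predict() below loads the most recent checkpoint in output_dir\n",
"# when checkpoint_path=None; print it to confirm fine-tuning produced one.\n",
"print(tf.train.latest_checkpoint(output_dir))"
]
},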
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.b Inference\n",
"\n",
"Now we run inference with the fine-tuned model just saved:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_examples = run_squad.read_squad_examples(\n",
" input_file=predict_file, is_training=False)\n",
"\n",
"eval_writer = run_squad.FeatureWriter(\n",
" filename=os.path.join(output_dir, \"eval.tf_record\"),\n",
" is_training=False)\n",
"\n",
"eval_features = []\n",
"def append_feature(feature):\n",
" eval_features.append(feature)\n",
" eval_writer.process_feature(feature)\n",
"\n",
"\n",
"# Loads a data file into a list of InputBatch's\n",
"run_squad.convert_examples_to_features(\n",
" examples=eval_examples,\n",
" tokenizer=tokenizer,\n",
" max_seq_length=max_seq_length,\n",
" doc_stride=doc_stride,\n",
" max_query_length=max_query_length,\n",
" is_training=False,\n",
" output_fn=append_feature)\n",
"\n",
"eval_writer.close()\n",
"\n",
"tf.logging.info(\"***** Running predictions *****\")\n",
"tf.logging.info(\" Num orig examples = %d\", len(eval_examples))\n",
"tf.logging.info(\" Num split examples = %d\", len(eval_features))\n",
"tf.logging.info(\" Batch size = %d\", predict_batch_size)\n",
"\n",
"predict_input_fn = run_squad.input_fn_builder(\n",
" input_file=eval_writer.filename,\n",
" batch_size=predict_batch_size,\n",
" seq_length=max_seq_length,\n",
" is_training=False,\n",
" drop_remainder=False)\n",
"\n",
"all_results = []\n",
"eval_hooks = [run_squad.LogEvalRunHook(predict_batch_size)]\n",
"eval_start_time = time.time()\n",
"for result in estimator.predict(\n",
" predict_input_fn, yield_single_examples=True, hooks=eval_hooks, checkpoint_path=None):\n",
" unique_id = int(result[\"unique_ids\"])\n",
" start_logits = [float(x) for x in result[\"start_logits\"].flat]\n",
" end_logits = [float(x) for x in result[\"end_logits\"].flat]\n",
" all_results.append(\n",
" run_squad.RawResult(\n",
" unique_id=unique_id,\n",
" start_logits=start_logits,\n",
" end_logits=end_logits))\n",
"\n",
"eval_time_elapsed = time.time() - eval_start_time\n",
"\n",
"time_list = eval_hooks[-1].time_list\n",
"time_list.sort()\n",
"eval_time_wo_startup = sum(time_list[:int(len(time_list) * 0.99)])\n",
"num_sentences = eval_hooks[-1].count * predict_batch_size\n",
"avg_sentences_per_second = num_sentences * 1.0 / eval_time_wo_startup\n",
"\n",
"tf.logging.info(\"-----------------------------\")\n",
"tf.logging.info(\"Total Inference Time = %0.2f Inference Time W/O start up overhead = %0.2f \"\n",
" \"Sentences processed = %d\", eval_time_elapsed, eval_time_wo_startup,\n",
" num_sentences)\n",
"tf.logging.info(\"Inference Performance = %0.4f sentences/sec\", avg_sentences_per_second)\n",
"tf.logging.info(\"-----------------------------\")\n",
"\n",
"output_prediction_file = os.path.join(output_dir, \"predictions.json\")\n",
"output_nbest_file = os.path.join(output_dir, \"nbest_predictions.json\")\n",
"output_null_log_odds_file = os.path.join(output_dir, \"null_odds.json\")\n",
"\n",
"run_squad.write_predictions(eval_examples, eval_features, all_results,\n",
" n_best_size, max_answer_length,\n",
" do_lower_case, output_prediction_file,\n",
" output_nbest_file, output_null_log_odds_file,\n",
" version_2_with_negative, verbose_logging)\n",
"\n",
"tf.logging.info(\"Inference Results:\")\n",
"\n",
"# Here we show only the prediction results, nbest prediction is also available in the output directory\n",
"results = \"\"\n",
"with open(output_prediction_file, 'r') as json_file:\n",
" data = json.load(json_file)\n",
" for question in eval_examples:\n",
" results += \"<tr><td>{}</td><td>{}</td><td>{}</td></tr>\".format(question.qas_id, question.question_text, data[question.qas_id])\n",
"\n",
"\n",
"from IPython.display import display, HTML\n",
"display(HTML(\"<table><tr><th>Id</th><th>Question</th><th>Answer</th></tr>{}</table>\".format(results))) "
]
},
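{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides `predictions.json`, `write_predictions` also produced an n-best file. As a sketch (assuming the standard `text` and `probability` fields written by run_squad), we can peek at the top candidate answers for the first question:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect the n-best candidate answers for the first eval question,\n",
"# assuming the text/probability fields that run_squad.write_predictions emits.\n",
"with open(output_nbest_file, 'r') as nbest_file:\n",
"    nbest = json.load(nbest_file)\n",
"\n",
"first_id = eval_examples[0].qas_id\n",
"for candidate in nbest[first_id][:5]:\n",
"    print(\"%.4f  %s\" % (candidate[\"probability\"], candidate[\"text\"]))"
]
},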
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.b Evaluation\n",
"\n",
"Let's run evaluation using the script in the SQuaD1.1 folder and our fine-tuned model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!python ../data/download/squad/v1.1/evaluate-v1.1.py \\\n",
" $predict_file \\\n",
" $output_dir/predictions.json"
]
},
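{
"cell_type": "markdown",
"metadata": {},
"source": [
"The script reports exact match (EM) and F1 over the dev set. For intuition, here is a minimal sketch of the token-level F1 it computes; the real evaluate-v1.1.py also normalizes articles and punctuation and takes the max over all ground-truth answers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"# Minimal sketch of SQuAD's token-level F1. The real evaluate-v1.1.py also\n",
"# normalizes text and takes the max over multiple ground-truth answers.\n",
"def f1_score(prediction, ground_truth):\n",
"    pred_tokens = prediction.lower().split()\n",
"    gt_tokens = ground_truth.lower().split()\n",
"    common = collections.Counter(pred_tokens) & collections.Counter(gt_tokens)\n",
"    num_same = sum(common.values())\n",
"    if num_same == 0:\n",
"        return 0.0\n",
"    precision = float(num_same) / len(pred_tokens)\n",
"    recall = float(num_same) / len(gt_tokens)\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"print(f1_score(\"in the teeth of the storm\", \"teeth of the storm\"))"
]
},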
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. What's next\n",
"\n",
"Now that you have fine-tuned a BERT model you may want to take a look at the run_squad script which containd more options for fine-tuning."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}