{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# =============================================================================="
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
"\n",
"# BioBERT Named-Entity Recognition Inference with Mixed Precision\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Overview\n",
"\n",
"Bidirectional Embedding Representations from Transformers (BERT), is a method of pre-training language representations which obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks. \n",
"\n",
"BioBERT is a domain specific version of BERT that has been trained on PubMed abstracts.\n",
"\n",
"The original BioBERT paper can be found here: https://arxiv.org/abs/1901.08746\n",
"\n",
"NVIDIA's BioBERT is an optimized version of the implementation presented in the paper, leveraging mixed precision arithmetic and tensor cores on V100 GPUS for faster training times while maintaining target accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.a Learning objectives\n",
"\n",
"This notebook demonstrates:\n",
"- Inference on NER task with BioBERT model\n",
"- The use/download of fine-tuned NVIDIA BioBERT models\n",
"- Use of Mixed Precision for Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Requirements\n",
"\n",
"Please refer to the ReadMe file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. BioBERT Inference: Named-Entity Recognition\n",
"\n",
"We can run inference on a fine-tuned BioBERT model for tasks like Named-Entity Recognition.\n",
"\n",
"Here we use a BioBERT model fine-tuned on a [BC5CDR-disease Dataset](https://www.ncbi.nlm.nih.gov/research/bionlp/Data/) which consists of 1500 PubMed articles with 5818 annotated diseases."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.a Extract Disease Information from Text\n",
"\n",
"In this example we will use Named-Entity Recognition model created using BioBERT to extract disease information from the following paragraph:\n",
"\n",
"**Input Text**\n",
"\n",
"_\"The authors describe the case of a 56 - year - old woman with chronic, severe heart failure \n",
"secondary to dilated cardiomyopathy and absence of significant ventricular arrhythmias \n",
"who developed QT prolongation and torsade de pointes ventricular tachycardia during one cycle \n",
"of intermittent low dose (2.5 mcg/kg per min) dobutamine. \n",
"This report of torsade de pointes ventricular tachycardia during intermittent dobutamine \n",
"supports the hypothesis that unpredictable fatal arrhythmias may occur even with low doses \n",
"and in patients with no history of significant rhythm disturbances.\n",
"The mechanisms of proarrhythmic effects of Dubutamine are discussed.\"_\n",
"\n",
"**Output visualized using displaCy**\n",
"\n",
"<div class=\"entities\" style=\"line-height: 2.5; direction: ltr\">The authors describe the case of a 56 year old woman with chronic , severe \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" heart failure \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"secondary to \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" dilated cardiomyopathy \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"and absence of significant \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" ventricular arrhythmias \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"who developed QT \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" prolongation \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"and torsade de pointes \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" ventricular tachycardia \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"during one cycle of intermittent low dose ( 2.5 mcg / kg per min ) dobutamine . This report of torsade de pointes \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" ventricular tachycardia \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"during intermittent dobutamine supports the hypothesis that unpredictable fatal \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" arrhythmias \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
"may occur even with low doses and in patients with no history of significant \n",
"<mark class=\"entity\" style=\"background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone\">\n",
" rhythm disturbances \n",
" <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem\">DISEASE</span>\n",
"</mark>\n",
". The mechanisms of proarrhythmic effects of Dubutamine are discussed . </div>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text= \"\"\"\n",
"The authors describe the case of a 56 year old woman with chronic, severe heart failure\n",
"secondary to dilated cardiomyopathy and absence of significant ventricular arrhythmias\n",
"who developed QT prolongation and torsade de pointes ventricular tachycardia during one cycle\n",
"of intermittent low dose (2.5 mcg/kg per min) dobutamine.\n",
"This report of torsade de pointes ventricular tachycardia during intermittent dobutamine\n",
"supports the hypothesis that unpredictable fatal arrhythmias may occur even with low doses\n",
"and in patients with no history of significant rhythm disturbances.\n",
"The mechanisms of proarrhythmic effects of Dubutamine are discussed.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"notebooks_dir = '../notebooks'\n",
"working_dir = '../'\n",
"if working_dir not in sys.path:\n",
" sys.path.append(working_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert the text into the IOB tags format seen during training, using dummy placeholder labels\n",
"import spacy\n",
"nlp = spacy.load(\"en_core_web_sm\")\n",
"\n",
"text = text.strip()\n",
"doc = nlp(text)\n",
"input_file = os.path.join(notebooks_dir, 'input.tsv')\n",
"with open(os.path.join(input_file), 'w') as wf: \n",
" for word in doc:\n",
" if word.text is '\\n':\n",
" continue\n",
" wf.write(word.text + '\\tO\\n')\n",
" wf.write('\\n') # Indicate end of text"
]
},
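{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional sanity check, we can print the first few token/label pairs written to `input.tsv`; each line should hold one token and the dummy `O` label, separated by a tab."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the first few lines of the IOB-formatted input file\n",
"with open(input_file) as rf:\n",
"    for line in rf.readlines()[:5]:\n",
"        print(repr(line.rstrip('\n')))"
]
},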
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.b Mixed Precision\n",
"\n",
"Mixed precision training offers significant computational speedup by performing operations in half-precision format, while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of tensor cores in the Volta and Turing architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures.\n",
"\n",
"For information about:\n",
"- How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.\n",
"- How to access and enable AMP for TensorFlow, see [Using TF-AMP](https://docs.nvidia.com/deeplearning/dgx/tensorflow-user-guide/index.html#tfamp) from the TensorFlow User Guide.\n",
"- Techniques used for mixed precision training, see the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we control mixed precision execution with the environmental variable:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"TF_ENABLE_AUTO_MIXED_PRECISION\"] = \"1\" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model we'll use was trained with mixed precision model, which takes much less time to train than the fp32 version, without losing accuracy. So we'll need to set with the following flag: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"use_mixed_precision_model = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Fine-Tuned NVIDIA BioBERT TF Models\n",
"\n",
"We have the following Named Entity Reconition models fine-tuned from BioBERT available on NGC (NVIDIA GPU Cluster, https://ngc.nvidia.com).\n",
"\n",
"| **Model** | **Description** |\n",
"|:---------:|:----------:|\n",
"|BioBERT NER BC5CDR Disease | NER model to extract disease information from text, trained on the BC5CDR-Disease dataset |\n",
"|BioBERT NER BC5CDR Chemical | NER model to extract chemical information from text, trained on the BC5CDR-Chemical dataset. |\n",
"\n",
"\n",
"For this exampple, we will download the Diease NER model trained from the BC5CDR-disease Dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# biobert_uncased_base_ner_disease\n",
"DATA_DIR_FP16 = '../data/download/finetuned_model_fp16'\n",
"!mkdir -p $DATA_DIR_FP16\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP16/biobert_uncased_base_ner_disease.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/biobert_uncased_base_ner_disease/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/biobert_uncased_base_ner_disease.zip "
]
},
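{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check, we can list what the archive unpacked. The file names referenced in Section 5.a (`bert_config.json`, `vocab.txt`, `model.ckpt-10251`) are assumed to be part of this layout."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List the unpacked model files; we expect bert_config.json, vocab.txt\n",
"# and model.ckpt-* checkpoint files (an assumption about the archive layout)\n",
"!ls -l $DATA_DIR_FP16"
]
},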
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the code that follows we will refer to these models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Running NER task inference\n",
"\n",
"In order to run NER inference we will follow step-by-step the flow implemented in run_ner.py."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.a Configure Things"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import run_ner\n",
"from run_ner import BC5CDRProcessor, model_fn_builder, file_based_input_fn_builder, filed_based_convert_examples_to_features, result_to_pair\n",
"\n",
"import os, sys\n",
"import time\n",
"\n",
"import tensorflow as tf\n",
"import modeling\n",
"import tokenization\n",
"\n",
"tf.logging.set_verbosity(tf.logging.ERROR)\n",
"\n",
"# Create the output directory where all the results are saved.\n",
"output_dir = os.path.join(working_dir, 'output')\n",
"tf.gfile.MakeDirs(output_dir)\n",
"\n",
"# The config json file corresponding to the pre-trained BERT model.\n",
"# This specifies the model architecture.\n",
"bert_config_file = os.path.join(DATA_DIR_FP16, 'bert_config.json')\n",
"\n",
"# The vocabulary file that the BERT model was trained on.\n",
"vocab_file = os.path.join(DATA_DIR_FP16, 'vocab.txt')\n",
"\n",
"init_checkpoint = os.path.join(DATA_DIR_FP16, 'model.ckpt-10251')\n",
"\n",
"# Whether to lower case the input text. \n",
"# Should be True for uncased models and False for cased models.\n",
"# The BioBERT available in NGC is uncased\n",
"do_lower_case = True\n",
" \n",
"# Total batch size for predictions\n",
"predict_batch_size = 1\n",
"params = dict([('batch_size', predict_batch_size)])\n",
"\n",
"# The maximum total input sequence length after WordPiece tokenization. \n",
"# Sequences longer than this will be truncated, and sequences shorter than this will be padded.\n",
"max_seq_length = 128\n",
"\n",
"# This is a WA to use flags from here:\n",
"flags = tf.flags\n",
"\n",
"if 'f' not in tf.flags.FLAGS: \n",
" tf.app.flags.DEFINE_string('f', '', 'kernel')\n",
"FLAGS = flags.FLAGS\n",
"\n",
"FLAGS.output_dir = output_dir"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.b Define Tokenizer & Create Estimator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Validate the casing config consistency with the checkpoint name.\n",
"tokenization.validate_case_matches_checkpoint(do_lower_case, init_checkpoint)\n",
"\n",
"# Create the tokenizer.\n",
"tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)\n",
"\n",
"# Load the configuration from file\n",
"bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
"\n",
"\n",
"# Use the data processor for BC5CDR\n",
"processor = BC5CDRProcessor()\n",
"# Get labels in the index order that was used during training\n",
"label_list = processor.get_labels()\n",
"\n",
"# Reverse index the labels. This will be used later when evaluating predictions.\n",
"id2label = {}\n",
"for (i, label) in enumerate(label_list, 1):\n",
" id2label[i] = label\n",
"\n",
"\n",
"config = tf.ConfigProto(log_device_placement=True) \n",
"run_config = tf.estimator.RunConfig(\n",
" model_dir=None,\n",
" session_config=config,\n",
" save_checkpoints_steps=1000,\n",
" keep_checkpoint_max=1)\n",
"\n",
"\n",
"# Use model function builder to create the model function\n",
"model_fn = model_fn_builder(\n",
" bert_config=bert_config,\n",
" num_labels=len(label_list) + 1,\n",
" init_checkpoint=init_checkpoint,\n",
" use_fp16=use_mixed_precision_model)\n",
"\n",
"estimator = tf.estimator.Estimator(\n",
" model_fn=model_fn,\n",
" config=run_config,\n",
" params=params)"
]
},
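{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small inspection step, we can print the label list and the 1-based reverse index built above (id 0 is reserved for padding). The exact label set is whatever `BC5CDRProcessor.get_labels()` returns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the decoding map used in sections 5.d and 5.e;\n",
"# ids are 1-based because id 0 is reserved for padding\n",
"print('label_list:', label_list)\n",
"print('id2label:', id2label)"
]
},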
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.c Run Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the input data using the BC5CDR processor\n",
"predict_examples = processor.get_test_examples(notebooks_dir, file_name='input.tsv')\n",
"\n",
"\n",
"# Convert to tf_records and save it\n",
"predict_file = os.path.join(output_dir, \"predict.tf_record\")\n",
"filed_based_convert_examples_to_features(predict_examples, label_list,\n",
" max_seq_length, tokenizer,\n",
" predict_file)\n",
"\n",
"\n",
"tf.logging.info(\"***** Running predictions *****\")\n",
"tf.logging.info(\" Num orig examples = %d\", len(predict_examples))\n",
"tf.logging.info(\" Batch size = %d\", predict_batch_size)\n",
"\n",
"# Run prediction on this tf_record file\n",
"predict_input_fn = file_based_input_fn_builder(\n",
" input_file=predict_file,\n",
" batch_size=predict_batch_size,\n",
" seq_length=max_seq_length,\n",
" is_training=False,\n",
" drop_remainder=False)\n",
"\n",
"\n",
"pred_start_time = time.time()\n",
"\n",
"predictions = estimator.predict(input_fn=predict_input_fn)\n",
"predictions = list(predictions)\n",
"\n",
"pred_time_elapsed = time.time() - pred_start_time\n",
"\n",
"tf.logging.info(\"-----------------------------\")\n",
"tf.logging.info(\"Total Inference Time = %0.2f\", pred_time_elapsed)\n",
"# tf.logging.info(\"Inference Performance = %0.4f sentences/sec\", avg_sentences_per_second)\n",
"tf.logging.info(\"-----------------------------\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.d Save Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's now process the predictions and save them to file(s)\n",
"tf.logging.info(\"Save Predictions:\")\n",
"\n",
"# File containing the list of predictions as IOB tags\n",
"output_predict_file = os.path.join(FLAGS.output_dir, \"label_test.txt\")\n",
"# File containing the list of words, the dummy token and the predicted IOB tag\n",
"test_labels_file = os.path.join(FLAGS.output_dir, \"test_labels.txt\")\n",
"test_labels_err_file = os.path.join(FLAGS.output_dir, \"test_labels_errs.txt\")\n",
"\n",
"with tf.gfile.Open(output_predict_file, 'w') as writer, \\\n",
" tf.gfile.Open(test_labels_file, 'w') as tl, \\\n",
" tf.gfile.Open(test_labels_err_file, 'w') as tle:\n",
" i=0\n",
" for prediction in estimator.predict(input_fn=predict_input_fn, yield_single_examples=True):\n",
" output_line = \"\\n\".join(id2label[id] for id in prediction if id != 0) + \"\\n\"\n",
" writer.write(output_line)\n",
" result_to_pair(predict_examples[i], prediction, id2label, tl, tle)\n",
" i = i + 1"
]
},
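{
"cell_type": "markdown",
"metadata": {},
"source": [
"To verify the files were written as expected, a minimal sketch that prints the first few rows of `test_labels.txt` (each row pairs a word and the dummy token with the predicted IOB tag):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at the word/label pairs produced by result_to_pair above\n",
"with tf.gfile.Open(test_labels_file) as rf:\n",
"    for line in rf.readlines()[:10]:\n",
"        print(line.rstrip())"
]
},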
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.e Visualize Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's create a function that can formats the predictions for display using displaCy\n",
"def predictions_for_displacy(predict_examples, predictions, id2label):\n",
" processed_text = ''\n",
" entities = []\n",
" current_pos = 0\n",
" start_pos = 0\n",
" end_pos = 0\n",
" end_detected = False\n",
" prev_label = ''\n",
"\n",
" for predict_line, pred_ids in zip(predict_examples, predictions):\n",
" words = str(predict_line.text).split(' ')\n",
" labels = str(predict_line.label).split(' ')\n",
"\n",
" # get from CLS to SEP\n",
" pred_labels = []\n",
" for id in pred_ids:\n",
" if id == 0:\n",
" continue\n",
" curr_label = id2label[id]\n",
" if curr_label == '[CLS]':\n",
" continue\n",
" elif curr_label == '[SEP]':\n",
" break\n",
" elif curr_label == 'X':\n",
" continue\n",
" pred_labels.append(curr_label)\n",
"\n",
" for tok, label, pred_label in zip(words, labels, pred_labels):\n",
" if pred_label is 'B':\n",
" start_pos = current_pos\n",
" elif pred_label is 'I' and prev_label is not 'B' and prev_label is not 'I':\n",
" start_pos = current_pos\n",
" elif pred_label is 'O' and (prev_label is 'B' or prev_label is 'I'):\n",
" end_pos = current_pos\n",
" end_detected = True\n",
"\n",
" if end_detected:\n",
" entities.append({'start':start_pos, 'end': end_pos, 'label': 'DISEASE'})\n",
" start_pos = 0\n",
" end_pos = 0\n",
" end_detected = False\n",
"\n",
" processed_text = processed_text + tok + ' '\n",
" current_pos = current_pos + len(tok) + 1\n",
" prev_label = pred_label\n",
"\n",
" #Handle entity at the very end\n",
" if start_pos > 0 and end_detected is False:\n",
" entities.append({'start':start_pos, 'end': current_pos, 'label': 'DISEASE'})\n",
" \n",
" displacy_input = [{\"text\": processed_text,\n",
" \"ents\": entities,\n",
" \"title\": None}]\n",
" \n",
" return displacy_input"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert the predictions to the Named Entities format required by displaCy and visualize\n",
"displacy_input = predictions_for_displacy(predict_examples, predictions, id2label)\n",
"html = spacy.displacy.render(displacy_input, style=\"ent\", manual=True)"
]
},
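{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to keep the rendered markup, for example to view it outside of Jupyter, `displacy.render` returns it as a string when called with `jupyter=False`. The output file name below is illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the rendered entity markup to an HTML file (file name is illustrative)\n",
"html = spacy.displacy.render(displacy_input, style=\"ent\", manual=True, jupyter=False)\n",
"with open(os.path.join(output_dir, 'ner_output.html'), 'w') as f:\n",
"    f.write(html)"
]
},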
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. What's next"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that you are familiar with running NER Inference on BioBERT, using mixed precision, you may want to try extracting disease information from other biomedical text. "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}