Merge pull request #824 from swethmandava/master

Updating notebooks and gpu_affinity scripts for bert tf1
This commit is contained in:
Swetha Mandava 2021-02-03 11:09:03 -08:00 committed by GitHub
commit 004385d0ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 313 additions and 199 deletions

View file

@ -50,7 +50,8 @@ Once the image is built, you need to run the container with the `--publish
at port `8888` over all network interfaces (`0.0.0.0`):
```bash
nvidia-docker run \
docker run \
--gpus all \
-v $PWD:/workspace/bert \
-v $PWD/results:/results \
--shm-size=1g \
@ -62,11 +63,6 @@ nvidia-docker run \
### 2.c Dataset
We need to download the vocabulary and the bert_config files:
```python3
python3 /workspace/bert/data/bertPrep.py --action download --dataset google_pretrained_weights # Includes vocab
```
This is only needed during fine-tuning in order to download the Squad dataset:
@ -134,7 +130,8 @@ Once the image is built, you need to run the container with the `--publish
at port `8888` over all network interfaces (`0.0.0.0`):
```bash
nvidia-docker run \
docker run \
--gpus all \
-v $PWD:/workspace/bert \
-v $PWD/results:/results \
--shm-size=1g \

View file

@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@ -144,9 +144,9 @@
"We will use large pre-trained models avaialble on NGC (NVIDIA GPU Cluster, https://ngc.nvidia.com).\n",
"There are many configuration available, in particular we will download and use the following:\n",
"\n",
"**bert_tf_large_fp16_384**\n",
"**bert_tf_ckpt_large_pretraining_amp_lamb**\n",
"\n",
"Which is pre-trained using the Wikipedia and Book corpus datasets as training data. \n",
"Which is pre-trained using the Wikipedia and Book corpus datasets as training data with AMP and LAMB optimizer. \n",
"We will fine-tune on the SQuaD 1.1 Dataset."
]
},
@ -163,12 +163,21 @@
"metadata": {},
"outputs": [],
"source": [
"# bert_tf_large_fp16_384\n",
"DATA_DIR_FP16 = data_dir + '/pretrained_model_fp16'\n",
"!mkdir -p $DATA_DIR_FP16\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP16/bert_for_tensorflow.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_for_tensorflow/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/bert_for_tensorflow.zip "
"# bert_tf_large pretrained model\n",
"DATA_DIR_PT = data_dir + '/pretrained_large_model'\n",
"!mkdir -p $DATA_DIR_PT\n",
"!wget --content-disposition -O $DATA_DIR_PT/bert_tf_ckpt_large_pretraining_amp_lamb_19.03.1.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_ckpt_large_pretraining_amp_lamb/versions/19.03.1/zip \\\n",
"&& unzip -n -d $DATA_DIR_PT/ $DATA_DIR_PT/bert_tf_ckpt_large_pretraining_amp_lamb_19.03.1.zip \\\n",
"&& rm $DATA_DIR_PT/bert_tf_ckpt_large_pretraining_amp_lamb_19.03.1.zip\n",
"\n",
"# bert_tf_large finetuned model on SQUAD1.1\n",
"DATA_DIR_FT = data_dir + '/finetuned_large_model_SQUAD1.1'\n",
"!mkdir -p $DATA_DIR_FT\n",
"!wget --content-disposition -O $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad11_amp_384_19.03.1.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_ckpt_large_qa_squad11_amp_384/versions/19.03.1/zip \\\n",
"&& unzip -n -d $DATA_DIR_FT/ $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad11_amp_384_19.03.1.zip \\\n",
"&& rm $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad11_amp_384_19.03.1.zip"
]
},
{
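For reference, the same download-and-extract step can be scripted without the shell magics; a minimal Python sketch, assuming `requests` is installed and the NGC endpoint above is still valid:

```python
# Hypothetical Python equivalent of the wget/unzip cell above.
import os
import zipfile
import requests

MODEL = "bert_tf_ckpt_large_pretraining_amp_lamb"
VERSION = "19.03.1"
URL = f"https://api.ngc.nvidia.com/v2/models/nvidia/{MODEL}/versions/{VERSION}/zip"

dest = "data/pretrained_large_model"  # assumed download directory
os.makedirs(dest, exist_ok=True)
archive = os.path.join(dest, f"{MODEL}_{VERSION}.zip")

with requests.get(URL, stream=True) as r:
    r.raise_for_status()
    with open(archive, "wb") as f:
        for chunk in r.iter_content(chunk_size=1 << 20):
            f.write(chunk)

with zipfile.ZipFile(archive) as z:
    z.extractall(dest)  # like `unzip -d`; unlike `unzip -n` this overwrites
os.remove(archive)      # mirrors the trailing `rm`
```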
@ -190,7 +199,7 @@
"if working_dir not in sys.path:\n",
" sys.path.append(working_dir)\n",
"\n",
"init_checkpoint = os.path.join(data_dir, 'pretrained_model_fp16/model.ckpt-1000000')"
"init_checkpoint = os.path.join(data_dir, 'pretrained_large_model/model.ckpt')"
]
},
{
@ -227,10 +236,10 @@
"\n",
"# The config json file corresponding to the pre-trained BERT model.\n",
"# This specifies the model architecture.\n",
"bert_config_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/bert_config.json')\n",
"bert_config_file = os.path.join(data_dir, 'finetuned_large_model_SQUAD1.1/bert_config.json')\n",
"\n",
"# The vocabulary file that the BERT model was trained on.\n",
"vocab_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/vocab.txt')\n",
"vocab_file = os.path.join(data_dir, 'finetuned_large_model_SQUAD1.1/vocab.txt')\n",
"\n",
"# Whether to lower case the input text. \n",
"# Should be True for uncased models and False for cased models.\n",
@ -279,7 +288,7 @@
"warmup_proportion = 0.1\n",
"\n",
"# # Total number of training epochs to perform (results will improve if trained with epochs)\n",
"num_train_epochs = 2\n",
"num_train_epochs = 1\n",
"\n",
"global_batch_size = train_batch_size\n",
"training_hooks = []\n",
@ -322,8 +331,7 @@
"num_train_steps = int(len(train_examples) / global_batch_size * num_train_epochs)\n",
"num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
"\n",
"# Pre-shuffle the input to avoid having to make a very large shuffle\n",
"# buffer in in the `input_fn`.\n",
"# Pre-shuffle the input to avoid having to make a very large shuffle buffer in in the `input_fn`.\n",
"rng = random.Random(12345)\n",
"rng.shuffle(train_examples)\n",
"\n",
@ -331,8 +339,7 @@
"end_index = len(train_examples)\n",
"tmp_filenames = os.path.join(output_dir, \"train.tf_record\")\n",
"\n",
"# We write to a temporary file to avoid storing very large constant tensors\n",
"# in memory.\n",
"# We write to a temporary file to avoid storing very large constant tensors in memory.\n",
"train_writer = run_squad.FeatureWriter(\n",
" filename=tmp_filenames,\n",
" is_training=True)\n",
@ -353,7 +360,7 @@
"tf.logging.info(\" Num split examples = %d\", train_writer.num_features)\n",
"tf.logging.info(\" Batch size = %d\", train_batch_size)\n",
"tf.logging.info(\" Num steps = %d\", num_train_steps)\n",
"tf.logging.info(\" LR = %f\", learning_rate)\n",
"tf.logging.info(\" Learning Rate = %f\", learning_rate)\n",
"\n",
"del train_examples"
]

View file

@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@ -210,22 +210,6 @@
"os.environ[\"TF_ENABLE_AUTO_MIXED_PRECISION\"] = \"1\" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can choose the mixed precision model (which takes much less time to train than the fp32 version) without losing accuracy, with the following flag: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"use_mixed_precision_model = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -258,9 +242,7 @@
"We will take advantage of the fine-tuned models available on NGC (NVIDIA GPU Cluster, https://ngc.nvidia.com).\n",
"Among the many configurations available we will download these two:\n",
"\n",
" - **bert_tf_v2_large_fp32_384**\n",
"\n",
" - **bert_tf_v2_large_fp16_384**\n",
" - **bert_tf_ckpt_large_qa_squad2_amp_384**\n",
"\n",
"Which are trained on the SQuaD 2.0 Dataset."
]
@ -271,19 +253,14 @@
"metadata": {},
"outputs": [],
"source": [
"# bert_tf_v2_large_fp32_384\n",
"DATA_DIR_FP32 = data_dir + '/finetuned_model_fp32'\n",
"!mkdir -p $DATA_DIR_FP32\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP32/bert_tf_v2_large_fp32_384.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v2_large_fp32_384/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP32/ $DATA_DIR_FP32/bert_tf_v2_large_fp32_384.zip \n",
"# bert_tf_ckpt_large_qa_squad2_amp_384\n",
"DATA_DIR_FT = data_dir + '/finetuned_large_model_SQUAD2.0'\n",
"!mkdir -p $DATA_DIR_FT\n",
" \n",
"# bert_tf_v2_large_fp16_384\n",
"DATA_DIR_FP16 = data_dir + '/finetuned_model_fp16'\n",
"!mkdir -p $DATA_DIR_FP16\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP16/bert_tf_v2_large_fp16_384.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v2_large_fp16_384/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/bert_tf_v2_large_fp16_384.zip "
"!wget --content-disposition -O $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad2_amp_384_19.03.1.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_ckpt_large_qa_squad2_amp_384/versions/19.03.1/zip \\\n",
"&& unzip -n -d $DATA_DIR_FT/ $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad2_amp_384_19.03.1.zip \\\n",
"&& rm -rf $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad2_amp_384_19.03.1.zip"
]
},
{
@ -326,16 +303,13 @@
"\n",
"# The config json file corresponding to the pre-trained BERT model.\n",
"# This specifies the model architecture.\n",
"bert_config_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/bert_config.json')\n",
"bert_config_file = os.path.join(data_dir, 'finetuned_large_model_SQUAD2.0/bert_config.json')\n",
"\n",
"# The vocabulary file that the BERT model was trained on.\n",
"vocab_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/vocab.txt')\n",
"vocab_file = os.path.join(data_dir, 'finetuned_large_model_SQUAD2.0/vocab.txt')\n",
"\n",
"# Depending on the mixed precision flag we use different fine-tuned model\n",
"if use_mixed_precision_model:\n",
" init_checkpoint = os.path.join(data_dir, 'finetuned_model_fp16/model.ckpt-8144')\n",
"else:\n",
" init_checkpoint = os.path.join(data_dir, 'finetuned_model_fp32/model.ckpt-8144')\n",
"# Initiate checkpoint to the fine-tuned BERT Large model\n",
"init_checkpoint = os.path.join(data_dir, 'finetuned_large_model_SQUAD2.0/model.ckpt')\n",
"\n",
"# Whether to lower case the input text. \n",
"# Should be True for uncased models and False for cased models.\n",

View file

@ -4,13 +4,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jDXroBuNw60P"
},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@ -28,7 +26,9 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "DXgkABw0IUIC"
},
"source": [
"<a href=\"https://colab.research.google.com/github/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/LanguageModeling/BERT/notebooks/bert_squad_tf_inference_colab.ipynb#scrollTo=5hRb96NKE3X0\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
@ -36,7 +36,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "k-XnFINow60d"
},
"source": [
@ -48,7 +47,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "TfF7V662w60j"
},
"source": [
@ -64,7 +62,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Ah3Lv9zyw60l"
},
"source": [
@ -79,7 +76,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hxNJ8HByw60o"
},
"source": [
@ -95,7 +91,31 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zyMjTVx1Ibt-",
"outputId": "216eeecc-3785-466b-d126-94f02abe19d8"
},
"outputs": [],
"source": [
"#Select lower version of tensroflow on Google Colab\n",
"%tensorflow_version 1.x\n",
"import tensorflow\n",
"print(tensorflow.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1PXBHL8KIUIF",
"outputId": "cb05ba96-4d21-4483-de95-dd724383dd80"
},
"outputs": [],
"source": [
"!nvidia-smi"
@ -104,8 +124,7 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hxNJ8HByw60o"
"id": "BgfmN3z6IUIF"
},
"source": [
"### 2.b Download the required files from NVIDIA-Github:"
@ -115,9 +134,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "KV_WnOY4zUa_"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KV_WnOY4zUa_",
"outputId": "a1b0f548-75f5-432d-8fbd-6fbad5f421a3"
},
"outputs": [],
"source": [
@ -130,9 +151,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5D7i7Pao5qoj"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5D7i7Pao5qoj",
"outputId": "d172417e-1cb1-4138-e1f4-1eda0f3d72bd"
},
"outputs": [],
"source": [
@ -146,7 +169,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mjlZbP0dw60r"
},
"source": [
@ -160,7 +182,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mOc16svBw60t"
},
"source": [
@ -201,9 +222,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "srU0TT1Iw60v"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "srU0TT1Iw60v",
"outputId": "b5340882-7b24-4343-ebc2-9ff709b489a3"
},
"outputs": [],
"source": [
@ -245,8 +268,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ujyka-8Iw603"
},
"outputs": [],
@ -263,8 +284,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "6gA3-6LVw61D"
},
"outputs": [],
@ -275,7 +294,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "D9p8XaBnw61N"
},
"source": [
@ -292,7 +310,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ceeYPqQcw61P"
},
"source": [
@ -303,8 +320,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "k4jIJevFw61R"
},
"outputs": [],
@ -316,30 +331,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rt_4-ZA5w61Y"
},
"source": [
"We can choose the mixed precision model (which takes much less time to train than the fp32 version) without losing accuracy, with the following flag: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BRdclfEaw61Z"
},
"outputs": [],
"source": [
"use_mixed_precision_model = True"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iu4Jb5puw61p"
},
"source": [
@ -355,9 +346,7 @@
"We will take advantage of the fine-tuned models available on NGC (NVIDIA GPU Cluster, https://ngc.nvidia.com).\n",
"Among the many configurations available we will download these two:\n",
"\n",
" - **bert_tf_v2_large_fp32_384**\n",
"\n",
" - **bert_tf_v2_large_fp16_384**\n",
" - **bert_tf_ckpt_large_qa_squad2_amp_384**\n",
"\n",
"Which are trained on the SQuaD 2.0 Dataset."
]
@ -366,31 +355,26 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5JWKZfP8w61t"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5JWKZfP8w61t",
"outputId": "c5b92300-1f61-435e-9f6c-e3b3c116161e"
},
"outputs": [],
"source": [
"# bert_tf_v2_large_fp32_384\n",
"DATA_DIR_FP32 = os.path.join(data_dir, 'finetuned_model_fp32')\n",
"!mkdir -p $DATA_DIR_FP32\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP32/bert_tf_v2_large_fp32_384.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v2_large_fp32_384/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP32/ $DATA_DIR_FP32/bert_tf_v2_large_fp32_384.zip \n",
" \n",
"# bert_tf_v2_large_fp16_384\n",
"DATA_DIR_FP16 = os.path.join(data_dir, 'finetuned_model_fp16')\n",
"!mkdir -p $DATA_DIR_FP16\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP16/bert_tf_v2_large_fp16_384.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v2_large_fp16_384/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/bert_tf_v2_large_fp16_384.zip "
"# bert_tf_ckpt_large_qa_squad2_amp_384\n",
"DATA_DIR_FT = os.path.join(data_dir, 'finetuned_large_model')\n",
"!mkdir -p $DATA_DIR_FT \n",
"!wget --content-disposition -O $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad2_amp_384_19.03.1.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_ckpt_large_qa_squad2_amp_384/versions/19.03.1/zip \\\n",
"&& unzip -n -d $DATA_DIR_FT/ $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad2_amp_384_19.03.1.zip \\\n",
"&& rm $DATA_DIR_FT/bert_tf_ckpt_large_qa_squad2_amp_384_19.03.1.zip"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "GrFrZickw61z"
},
"source": [
@ -400,7 +384,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cU8mGJDa1FfX"
},
"source": [
@ -411,9 +394,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5hRb96NKE3X0"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5hRb96NKE3X0",
"outputId": "85b50f58-d89f-41b1-8ac1-55c3117743ea"
},
"outputs": [],
"source": [
@ -427,7 +412,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VY1Dipam15DE"
},
"source": [
@ -438,8 +422,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jqAJob92C2wA"
},
"outputs": [],
@ -447,13 +429,12 @@
"try:\n",
" __import__(\"horovod\")\n",
"except ImportError:\n",
" os.system(\"pip install horovod\")"
" os.system(\"pip install --no-cache-dir horovod\")"
]
},
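The same lazy-install guard can be written with `subprocess`, which raises on a failed install instead of silently ignoring the exit code; a sketch (not part of this diff):

```python
import subprocess
import sys

try:
    import horovod.tensorflow  # noqa: F401
except ImportError:
    # check_call raises CalledProcessError on failure, unlike os.system
    subprocess.check_call([sys.executable, "-m", "pip",
                           "install", "--no-cache-dir", "horovod"])
```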
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "5NuuGNsDw611"
},
"source": [
@ -468,8 +449,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "_c2qCQ9-w613"
},
"outputs": [],
@ -490,16 +469,14 @@
"\n",
"# The config json file corresponding to the pre-trained BERT model.\n",
"# This specifies the model architecture.\n",
"bert_config_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/bert_config.json')\n",
"bert_config_file = os.path.join(data_dir, 'finetuned_large_model_SQUAD2.0/bert_config.json')\n",
"\n",
"# The vocabulary file that the BERT model was trained on.\n",
"vocab_file = os.path.join(data_dir, 'google_pretrained_weights/uncased_L-24_H-1024_A-16/vocab.txt')\n",
"vocab_file = os.path.join(data_dir, 'finetuned_large_model_SQUAD2.0/vocab.txt')\n",
"\n",
"# Initiate checkpoint to the fine-tuned BERT Large model\n",
"init_checkpoint = os.path.join(data_dir, 'finetuned_large_model/model.ckpt')\n",
"\n",
"# Depending on the mixed precision flag we use different fine-tuned model\n",
"if use_mixed_precision_model:\n",
" init_checkpoint = os.path.join(data_dir, 'finetuned_model_fp16/model.ckpt-8144')\n",
"else:\n",
" init_checkpoint = os.path.join(data_dir, 'finetuned_model_fp32/model.ckpt-8144')\n",
"\n",
"# Whether to lower case the input text. \n",
"# Should be True for uncased models and False for cased models.\n",
@ -542,7 +519,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2h_eLUgPw618"
},
"source": [
@ -553,9 +529,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RXHdoUb9w619"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RXHdoUb9w619",
"outputId": "6079b4b3-c535-4b0e-aa17-d8a505cdb3ea"
},
"outputs": [],
"source": [
@ -611,7 +589,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xSKkf4JLw62E"
},
"source": [
@ -622,9 +599,12 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "3OKhc349w62F",
"outputId": "07d056d2-6d55-4922-88b3-bd25802f70fd",
"scrolled": true
},
"outputs": [],
@ -702,7 +682,8 @@
"run_squad.write_predictions(eval_examples, eval_features, all_results,\n",
" n_best_size, max_answer_length,\n",
" do_lower_case, output_prediction_file,\n",
" output_nbest_file, output_null_log_odds_file)\n",
" output_nbest_file, output_null_log_odds_file,\n",
" version_2_with_negative, verbose_logging)\n",
"\n",
"tf.logging.info(\"Inference Results:\")\n",
"\n",
@ -721,7 +702,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "EMT0sKxHw62L"
},
"source": [
@ -731,7 +711,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mKBM_UD6w62N"
},
"source": [
@ -745,7 +724,7 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "bert_squad_tf_inference.ipynb",
"name": "Copy of bert_squad_tf_inference.ipynb",
"provenance": []
},
"kernelspec": {

View file

@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
"# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@ -227,22 +227,6 @@
"os.environ[\"TF_ENABLE_AUTO_MIXED_PRECISION\"] = \"1\" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model we'll use was trained with mixed precision model, which takes much less time to train than the fp32 version, without losing accuracy. So we'll need to set with the following flag: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"use_mixed_precision_model = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -266,12 +250,13 @@
"metadata": {},
"outputs": [],
"source": [
"# biobert_uncased_base_ner_disease\n",
"DATA_DIR_FP16 = '../data/download/finetuned_model_fp16'\n",
"!mkdir -p $DATA_DIR_FP16\n",
"!wget -nc -q --show-progress -O $DATA_DIR_FP16/biobert_uncased_base_ner_disease.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/biobert_uncased_base_ner_disease/versions/1/zip\n",
"!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/biobert_uncased_base_ner_disease.zip "
"# biobert_tf_uncased_base_ner_disease\n",
"DATA_DIR_BIOBERT = '../data/download/finetuned_biobert_model'\n",
"!mkdir -p $DATA_DIR_BIOBERT\n",
"!wget --content-disposition -O $DATA_DIR_BIOBERT/biobert_tf_uncased_base_ner_disease_19.08.1.zip \\\n",
"https://api.ngc.nvidia.com/v2/models/nvidia/biobert_tf_uncased_base_ner_disease/versions/19.08.1/zip \\\n",
"&& unzip -n -d $DATA_DIR_BIOBERT/ $DATA_DIR_BIOBERT/biobert_tf_uncased_base_ner_disease_19.08.1.zip \\\n",
"&& rm $DATA_DIR_BIOBERT/biobert_tf_uncased_base_ner_disease_19.08.1.zip"
]
},
{
@ -321,12 +306,12 @@
"\n",
"# The config json file corresponding to the pre-trained BERT model.\n",
"# This specifies the model architecture.\n",
"bert_config_file = os.path.join(DATA_DIR_FP16, 'bert_config.json')\n",
"bert_config_file = os.path.join(DATA_DIR_BIOBERT, 'bert_config.json')\n",
"\n",
"# The vocabulary file that the BERT model was trained on.\n",
"vocab_file = os.path.join(DATA_DIR_FP16, 'vocab.txt')\n",
"vocab_file = os.path.join(DATA_DIR_BIOBERT, 'vocab.txt')\n",
"\n",
"init_checkpoint = os.path.join(DATA_DIR_FP16, 'model.ckpt-10251')\n",
"init_checkpoint = os.path.join(DATA_DIR_BIOBERT, 'model.ckpt')\n",
"\n",
"# Whether to lower case the input text. \n",
"# Should be True for uncased models and False for cased models.\n",
@ -397,13 +382,13 @@
"model_fn = model_fn_builder(\n",
" bert_config=bert_config,\n",
" num_labels=len(label_list) + 1,\n",
" init_checkpoint=init_checkpoint,\n",
" use_fp16=use_mixed_precision_model)\n",
" init_checkpoint=init_checkpoint)\n",
"# amp=use_amp)\n",
"\n",
"estimator = tf.estimator.Estimator(\n",
" model_fn=model_fn,\n",
" config=run_config,\n",
" params=params)"
" model_fn=model_fn,\n",
" config=run_config,\n",
" params=params)"
]
},
{
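With the estimator built, NER inference would typically proceed via `predict`; a hypothetical sketch (`predict_input_fn` is assumed here, not defined in this hunk):

```python
# Hypothetical usage of the estimator constructed above.
for prediction in estimator.predict(input_fn=predict_input_fn):
    print(prediction)  # per-token outputs for one example
```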

View file

@ -0,0 +1,100 @@
The O
authors O
describe O
the O
case O
of O
a O
56 O
year O
old O
woman O
with O
chronic O
, O
severe O
heart O
failure O
secondary O
to O
dilated O
cardiomyopathy O
and O
absence O
of O
significant O
ventricular O
arrhythmias O
who O
developed O
QT O
prolongation O
and O
torsade O
de O
pointes O
ventricular O
tachycardia O
during O
one O
cycle O
of O
intermittent O
low O
dose O
( O
2.5 O
mcg O
/ O
kg O
per O
min O
) O
dobutamine O
. O
This O
report O
of O
torsade O
de O
pointes O
ventricular O
tachycardia O
during O
intermittent O
dobutamine O
supports O
the O
hypothesis O
that O
unpredictable O
fatal O
arrhythmias O
may O
occur O
even O
with O
low O
doses O
and O
in O
patients O
with O
no O
history O
of O
significant O
rhythm O
disturbances O
. O
The O
mechanisms O
of O
proarrhythmic O
effects O
of O
Dubutamine O
are O
discussed O
. O
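The new file uses the two-column CoNLL-style format (token, space, tag) consumed by the NER scripts; a minimal reader sketch with a hypothetical filename:

```python
# Minimal parser for the token/tag format above (path is hypothetical).
def read_conll(path):
    sentences, tokens, tags = [], [], []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:  # blank line separates sentences
                if tokens:
                    sentences.append((tokens, tags))
                    tokens, tags = [], []
                continue
            token, tag = line.rsplit(" ", 1)
            tokens.append(token)
            tags.append(tag)
    if tokens:
        sentences.append((tokens, tags))
    return sentences

# e.g. read_conll("data/ner_disease_sample.txt")
```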

View file

@ -30,6 +30,7 @@ import tensorflow as tf
import horovod.tensorflow as hvd
import time
from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity
from utils.create_glue_data import *
@ -511,6 +512,7 @@ def main(_):
master_process = (hvd.rank() == 0)
hvd_rank = hvd.rank()
config.gpu_options.visible_device_list = str(hvd.local_rank())
set_affinity(hvd.local_rank())
if hvd.size() > 1:
  training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
if FLAGS.use_xla:
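The hunks in these training scripts all extend the same one-process-per-GPU Horovod pattern; a condensed sketch of the full sequence (Horovod and the new `utils.gpu_affinity` module assumed available):

```python
# Condensed sketch of the pattern these diffs extend.
import horovod.tensorflow as hvd
import tensorflow as tf
from utils.gpu_affinity import set_affinity

hvd.init()
config = tf.compat.v1.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())  # one GPU per rank
set_affinity(hvd.local_rank())  # pin this rank to the CPUs nearest its GPU
```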

View file

@ -26,6 +26,7 @@ import tf_metrics
import time
import horovod.tensorflow as hvd
from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity
@ -676,6 +677,7 @@ def main(_):
master_process = (hvd.rank() == 0)
hvd_rank = hvd.rank()
config.gpu_options.visible_device_list = str(hvd.local_rank())
set_affinity(hvd.local_rank())
if hvd.size() > 1:
  training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

View file

@ -28,6 +28,7 @@ import tensorflow as tf
import glob
from utils.utils import LogEvalRunHook, setup_xla_flags
import utils.dllogger_class
from utils.gpu_affinity import set_affinity
from dllogger import Verbosity
from tensorflow.core.protobuf import rewriter_config_pb2
@ -573,6 +574,7 @@ def main(_):
config = tf.compat.v1.ConfigProto()
if FLAGS.horovod:
  config.gpu_options.visible_device_list = str(hvd.local_rank())
  set_affinity(hvd.local_rank())
  if hvd.rank() == 0:
    tf.compat.v1.logging.info("***** Configuration *****")
    for key in FLAGS.__flags.keys():

View file

@ -36,6 +36,7 @@ import tokenization
import time
import horovod.tensorflow as hvd
from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity

View file

@ -37,6 +37,7 @@ import optimization
import tokenization
from utils.create_squad_data import *
from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity
@ -983,6 +984,7 @@ def main(_):
master_process = (hvd.rank() == 0)
hvd_rank = hvd.rank()
config.gpu_options.visible_device_list = str(hvd.local_rank())
set_affinity(hvd.local_rank())
if hvd.size() > 1:
  training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
if FLAGS.use_xla:
@ -1225,4 +1227,4 @@ def main(_):
if __name__ == "__main__":
FLAGS = extract_run_squad_flags()
tf.app.run()
tf.app.run()

View file

@ -0,0 +1,63 @@
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os

import pynvml

pynvml.nvmlInit()


def systemGetDriverVersion():
  return pynvml.nvmlSystemGetDriverVersion()


def deviceGetCount():
  return pynvml.nvmlDeviceGetCount()


class device:
  # assume nvml returns list of 64 bit ints
  _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)

  def __init__(self, device_idx):
    super().__init__()
    self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)

  def getName(self):
    return pynvml.nvmlDeviceGetName(self.handle)

  def getCpuAffinity(self):
    # Concatenate the 64-bit affinity masks reported by NVML into one
    # bit string, most-significant word first.
    affinity_string = ''
    for j in pynvml.nvmlDeviceGetCpuAffinity(
        self.handle, device._nvml_affinity_elements
    ):
      # assume nvml returns list of 64 bit ints
      affinity_string = '{:064b}'.format(j) + affinity_string
    affinity_list = [int(x) for x in affinity_string]
    affinity_list.reverse()  # so core 0 is in 0th element of list
    # Indices of the set bits are the CPU cores local to this GPU.
    return [i for i, e in enumerate(affinity_list) if e != 0]


def set_affinity(gpu_id=None):
  if gpu_id is None:
    gpu_id = int(os.getenv('LOCAL_RANK', 0))

  # Bind the current process (pid 0 = self) to the CPUs nearest this GPU.
  dev = device(gpu_id)
  os.sched_setaffinity(0, dev.getCpuAffinity())

  # list of ints representing the logical cores this process is now affinitied with
  return os.sched_getaffinity(0)
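A minimal usage sketch for the new module, assuming `pynvml` is installed and an NVIDIA driver with at least one visible GPU:

```python
from utils.gpu_affinity import set_affinity

cores = set_affinity(0)  # pin this process to the CPUs nearest GPU 0
print("bound to %d logical cores:" % len(cores), sorted(cores))
```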