support arbitrary sequence

This commit is contained in:
xjia 2019-07-19 21:23:57 +08:00
parent 81b1b024e1
commit 9f00666591
11 changed files with 2116 additions and 13 deletions

4
FasterTransformer/.gitmodules vendored Normal file
View file

@ -0,0 +1,4 @@
[submodule "sample/fastertransformer_bert/bert"]
path = sample/tensorflow_bert/bert
url = https://github.com/google-research/bert.git

View file

@ -3,7 +3,7 @@ Faster Transformer
## What is it?
The Faster Transformer implements an equivalent but highly optimized BERT transformer layer for inference. On Volta and Turing GPUs, FP16 precision is used automatically to access the computing power of tensor cores.
Faster Transformer is built on top of the CUDA and cuBLAS. It supports three kinds of sequence lengths, 32, 64 and 128. Two key parameters of the transformer layer, the number of heads and the size of each head, are passed in runtime. Thus, not only the BERT Base (12 heads * 64 per head) , but also customized models like 4 heads * 32 per head and 8 heads * 96 per heads, are well supported. Our implementation shows good speedups on both small and large batch size cases.
Faster Transformer is built on top of the CUDA and cuBLAS. It supports sequence lengths that are larger than 3 and smaller or equal to 1024. Two key parameters of the transformer layer, the number of heads and the size of each head, are passed in runtime. Thus, not only the BERT Base (12 heads * 64 per head) , but also customized models like 4 heads * 32 per head and 8 heads * 96 per heads, are well supported. Our implementation shows good speedups on both small and large batch size cases.
C++ API, TensorRT plugin, and TensorFlow OP wrapper are available. You can easily integrate this optimized transformer layer into your TensorFlow or other inference service codes that built in native C++ or TensorRT. In addition to codes that illustrate the API invocations, we also provide a simple end-to-end BERT TensorFlow inference sample.
@ -61,7 +61,7 @@ For large batch size case, we report both Tensorflow XLA and faster transformer'
|--/trt_plugin: TensorRT plugin implementation
/sample: c++ and tensorflow transformer interface samples
|--/cpp: both FP16 and FP32 c++ interface samples
|--/fastertransformer_bert: samples that show of how to integrate our Tensorflow OP into the open source BERT model for sentence (and sentence-pair) classification tasks (GLUE), the samples support both FP16 and FP32, see readme file within this folder more details
|--/tensorflow_bert: samples that show of how to integrate our Tensorflow OP into the open source BERT model for sentence (and sentence-pair) classification tasks (GLUE), the samples support both FP16 and FP32, see readme file within this folder more details
|--/tensorflow: both FP16 and FP32 tensorflow OP samples
|--/tensorRT: both FP16 and FP32 tensorRT plugin samples
/tools/gemm_test: loop over all GEMM algorithms to pick the best one
@ -88,7 +88,8 @@ $ make
Note: xx is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4).
### Execute demos ###
```shell
$ Step1 Generate the gemm_config.in file under the path build to pick GEMM algorithms for the best performance.
$ To achieve the best performance, please execute step1 and step2 together when you test a new model.
$ Step1 Generate the gemm_config.in file under the path build to pick GEMM algorithms for the best performance.
$ ./build/bin/gemm_fp16(32) <batch_size> <seq_len> <head_num> <size_per_head>
$ Step2 Execute demos
$ 1. Tensorflow demos: python build/transformer_fp16(32).py <batch_size> <num_layers> <seq_len> <head_num> <size_per_head>

View file

@ -166,8 +166,8 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
for(int i = 0; i < seq_len; ++i)
{
T qk = qk_buf_[threadIdx.x + qk_offset];
T mask_val = attr_mask[threadIdx.x + mask_offset];
T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
@ -179,7 +179,9 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
s_sum = sum_val + 1e-6f;
}
__syncthreads();
qk_buf_[threadIdx.x + qk_offset] = qk / (T)s_sum;
if(threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = qk / (T)s_sum;
qk_offset += seq_len;
mask_offset += seq_len;
@ -199,12 +201,12 @@ void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, con
__shared__ float s_sum;
T qk = qk_buf_[threadIdx.x + qk_offset];
T mask_val = attr_mask[threadIdx.x + mask_offset];
T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
float qk_tmp = __expf((float)(qk * scaler + mask_val));
float qk_tmp = threadIdx.x < seq_len ? __expf((float)(qk * scaler + mask_val)) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp);
if(threadIdx.x == 0)
@ -212,7 +214,9 @@ void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, con
s_sum = sum_val + 1e-6f;
}
__syncthreads();
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
if(threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
}
template<typename T>
@ -272,7 +276,8 @@ void OpenMultiHeadAttention<OpType_>::multiHeadAttr_nofuse_kernelLauncher(
if(OpType_ == OperationType::FP32)
{
const int word_per_block = 32;
// const int word_per_block = 32;
const int word_per_block = 1;
assert(k > 1024);
assert(m / word_per_block * 3 > 65536);
@ -307,16 +312,27 @@ void OpenMultiHeadAttention<OpType_>::multiHeadAttr_nofuse_kernelLauncher(
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[1])));
if(seq_len <= 32)
block.x = 32;
else if(seq_len > 32 && seq_len <= 64)
block.x = 64;
else if(seq_len > 64 && seq_len <= 128)
block.x = 128;
else if(seq_len > 128 && seq_len <= 256)
block.x = 256;
else if(seq_len > 256 && seq_len <= 512)
block.x = 512;
else
block.x = 1024;
if(batch_size * head_num <= 120)
{
grid.x = batch_size * head_num * seq_len;
block.x = seq_len;
softmax_kernel_v2<DataType_><<<grid, block, 0, stream>>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler);
}
else
{
grid.x = batch_size * head_num;
block.x = seq_len;
softmax_kernel<DataType_><<<grid, block, 0, stream>>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler);
}
@ -336,8 +352,12 @@ void OpenMultiHeadAttention<OpType_>::multiHeadAttr_nofuse_kernelLauncher(
if(OpType_ == OperationType::HALF)
{
const int seq_per_block = 4;
// const int seq_per_block = 1;
grid.x = batch_size * head_num * seq_len / seq_per_block;
block.x = seq_per_block * size_per_head / 2;
assert(grid.x * seq_per_block != batch_size * head_num * seq_len);
transpose<DataType_><<<grid, block, 0, stream>>>(transpose_dst_, dst,
batch_size, seq_len, head_num, size_per_head / 2);
}

View file

@ -0,0 +1,70 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# python ckpt_type_convert.py --init_checkpoint=mrpc_output/model.ckpt-343 --fp16_checkpoint=mrpc_output/fp16_model.ckpt
import tensorflow as tf
import numpy as np
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder
def checkpoint_dtype_cast(in_checkpoint_file, out_checkpoint_file):
var_list = checkpoint_utils.list_variables(tf.flags.FLAGS.init_checkpoint)
def init_graph():
for name, shape in var_list:
var = checkpoint_utils.load_variable(tf.flags.FLAGS.init_checkpoint, name)
recon_dtype = tf.float16 if var.dtype == np.float32 else var.dtype
tf.get_variable(name, shape=shape, dtype=recon_dtype)
init_graph()
saver = tf.train.Saver(builder=CastFromFloat32SaverBuilder())
with tf.Session() as sess:
saver.restore(sess, in_checkpoint_file)
saver.save(sess, 'tmp.ckpt')
tf.reset_default_graph()
init_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'tmp.ckpt')
saver.save(sess, out_checkpoint_file)
class CastFromFloat32SaverBuilder(BaseSaverBuilder):
# Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
def bulk_restore(self, filename_tensor, saveables, preferred_shard,
restore_sequentially):
restore_specs = []
for saveable in saveables:
for spec in saveable.specs:
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
names, slices, dtypes = zip(*restore_specs)
restore_dtypes = [tf.float32 if dtype.base_dtype==tf.float16 else dtype for dtype in dtypes]
# print info
for i in range(len(restore_specs)):
print(names[i], 'from', restore_dtypes[i], 'to', dtypes[i].base_dtype)
with tf.device("cpu:0"):
restored = io_ops.restore_v2(
filename_tensor, names, slices, restore_dtypes)
return [tf.cast(r, dt.base_dtype) for r, dt in zip(restored, dtypes)]
if __name__ == '__main__':
tf.flags.DEFINE_string("fp16_checkpoint", None, "fp16 checkpoint file")
tf.flags.DEFINE_string("init_checkpoint", None, "initial checkpoint file")
checkpoint_dtype_cast(tf.flags.FLAGS.init_checkpoint, tf.flags.FLAGS.fp16_checkpoint)

View file

@ -0,0 +1,286 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
import tensorflow as tf
import os
import sys
from my_modeling import *
build_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../build/lib')
transformer_op_module = tf.load_op_library(
os.path.join(build_path, 'libtf_fastertransformer.so'))
def file_based_input_fn_builder_drop(input_file, seq_length, is_training,
drop_remainder):
""" Re-implementation of file_based_input_fn_builder function from modeling.py from Google's BERT repository https://github.com/google-research/bert
with drop_remainder=True.
"""
name_to_features = {
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.FixedLenFeature([], tf.int64),
"is_real_example": tf.FixedLenFeature([], tf.int64),
}
def _decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.to_int32(t)
example[name] = t
return example
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
if is_training:
d = d.repeat()
d = d.shuffle(buffer_size=100)
# FASTINFER: drop remainder always
d = d.apply(
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
drop_remainder=True))
return d
return input_fn
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
labels, num_labels, use_one_hot_embeddings):
"""Creates a classification model."""
model = BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
output_layer = model.get_pooled_output()
hidden_size = output_layer.shape[-1].value
output_weights = tf.get_variable(
"output_weights", [num_labels, hidden_size],
dtype=tf.flags.FLAGS.floatx,
initializer=tf.truncated_normal_initializer(stddev=0.02))
output_bias = tf.get_variable(
"output_bias", [num_labels],
dtype=tf.flags.FLAGS.floatx,
initializer=tf.zeros_initializer())
with tf.variable_scope("loss"):
if is_training:
# I.e., 0.1 dropout
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
probabilities = tf.nn.softmax(logits, axis=-1)
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.flags.FLAGS.floatx)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
return (loss, per_example_loss, logits, probabilities)
def get_available_gpus():
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == 'GPU']
def fast_transformer_model_trans(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
do_return_all_layers=False):
""" Re-implementation of transformer_model function from modeling.py from Google's BERT repository https://github.com/google-research/bert
using FasterTransformer Tensorflow op.
Multi-headed, multi-layer Transformer from "Attention is All You Need".
This is almost an exact implementation of the original Transformer encoder.
See the original paper:
https://arxiv.org/abs/1706.03762
Also see:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
Args:
input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
seq_length], with 1 for positions that can be attended to and 0 in
positions that should not be.
hidden_size: int. Hidden size of the Transformer.
num_hidden_layers: int. Number of layers (blocks) in the Transformer.
num_attention_heads: int. Number of attention heads in the Transformer.
intermediate_size: int. The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: function. The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: float. Dropout probability for the hidden layers.
attention_probs_dropout_prob: float. Dropout probability of the attention
probabilities.
initializer_range: float. Range of the initializer (stddev of truncated
normal).
do_return_all_layers: Whether to also return all layers or just the final
layer.
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size], the final
hidden layer of the Transformer.
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
# The Transformer performs sum residuals on all layers so the input needs
# to be the same as the hidden size.
if input_width != hidden_size:
raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
(input_width, hidden_size))
# We keep the representation as a 2D tensor to avoid re-shaping it back and
# forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
# the GPU/CPU but may not be free on the TPU, so we want to minimize them to
# help the optimizer.
prev_output = reshape_to_matrix(input_tensor)
all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
layer_input = prev_output
with tf.variable_scope("attention"):
attention_heads = []
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_heads.append(attention_head)
attention_output = None
if len(attention_heads) == 1:
attention_output = attention_heads[0]
else:
# In the case where we have other sequences, we just concatenate
# them to the self-attention head before the projection.
attention_output = tf.concat(attention_heads, axis=-1)
# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(
attention_output, hidden_dropout_prob)
attention_output = layer_norm(
attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
# FASTINFER: fast transformer encoder inference
trainable_vars = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name)
layer_output = transformer_op_module.bert_transformer(
layer_input,
layer_input,
trainable_vars[0], trainable_vars[2], trainable_vars[4], trainable_vars[1], trainable_vars[3], trainable_vars[5],
attention_mask,
trainable_vars[6], trainable_vars[7], trainable_vars[8], trainable_vars[9], trainable_vars[10], trainable_vars[11],
trainable_vars[12], trainable_vars[13], trainable_vars[14], trainable_vars[15],
batch_size=batch_size, from_seq_len=seq_length, to_seq_len=seq_length, head_num=num_attention_heads, size_per_head=attention_head_size)
prev_output = layer_output
all_layer_outputs.append(layer_output)
if do_return_all_layers:
final_outputs = []
for layer_output in all_layer_outputs:
final_output = reshape_from_matrix(layer_output, input_shape)
final_outputs.append(final_output)
return final_outputs
else:
final_output = reshape_from_matrix(prev_output, input_shape)
return final_output

View file

@ -0,0 +1,991 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is mostly the same as bert/modeling.py from Google's BERT repository https://github.com/google-research/bert
# with configurable float types by setting tf.flags.FLAGS.floatx
"""The main BERT model and related functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import re
import numpy as np
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class BertModel(object):
"""BERT model ("Bidirectional Encoder Representations from Transformers").
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = modeling.BertModel(config=config, is_training=True,
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
label_embeddings = tf.get_variable(...)
pooled_output = model.get_pooled_output()
logits = tf.matmul(pooled_output, label_embeddings)
...
```
"""
def __init__(self,
config,
is_training,
input_ids,
input_mask=None,
token_type_ids=None,
use_one_hot_embeddings=False,
scope=None):
"""Constructor for BertModel.
Args:
config: `BertConfig` instance.
is_training: bool. true for training model, false for eval model. Controls
whether dropout will be applied.
input_ids: int32 Tensor of shape [batch_size, seq_length].
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
embeddings or tf.embedding_lookup() for the word embeddings.
scope: (optional) variable scope. Defaults to "bert".
Raises:
ValueError: The config is invalid or one of the input tensor shapes
is invalid.
"""
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
input_shape = get_shape_list(input_ids, expected_rank=2)
batch_size = input_shape[0]
seq_length = input_shape[1]
if input_mask is None:
input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
input_ids=input_ids,
vocab_size=config.vocab_size,
embedding_size=config.hidden_size,
initializer_range=config.initializer_range,
word_embedding_name="word_embeddings",
use_one_hot_embeddings=use_one_hot_embeddings)
# Add positional embeddings and token type embeddings, then layer
# normalize and perform dropout.
self.embedding_output = embedding_postprocessor(
input_tensor=self.embedding_output,
use_token_type=True,
token_type_ids=token_type_ids,
token_type_vocab_size=config.type_vocab_size,
token_type_embedding_name="token_type_embeddings",
use_position_embeddings=True,
position_embedding_name="position_embeddings",
initializer_range=config.initializer_range,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
with tf.variable_scope("encoder"):
# This converts a 2D mask of shape [batch_size, seq_length] to a 3D
# mask of shape [batch_size, seq_length, seq_length] which is used
# for the attention scores.
attention_mask = create_attention_mask_from_input_mask(
input_ids, input_mask)
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model(
input_tensor=self.embedding_output,
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_act_fn=get_activation(config.hidden_act),
hidden_dropout_prob=config.hidden_dropout_prob,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
initializer_range=config.initializer_range,
do_return_all_layers=True)
self.sequence_output = self.all_encoder_layers[-1]
# The "pooler" converts the encoded sequence tensor of shape
# [batch_size, seq_length, hidden_size] to a tensor of shape
# [batch_size, hidden_size]. This is necessary for segment-level
# (or segment-pair-level) classification tasks where we need a fixed
# dimensional representation of the segment.
with tf.variable_scope("pooler"):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token. We assume that this has been pre-trained
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
kernel_initializer=create_initializer(config.initializer_range))
def get_pooled_output(self):
return self.pooled_output
def get_sequence_output(self):
"""Gets final hidden layer of encoder.
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
to the final hidden of the transformer encoder.
"""
return self.sequence_output
def get_all_encoder_layers(self):
return self.all_encoder_layers
def get_embedding_output(self):
"""Gets output of the embedding lookup (i.e., input to the transformer).
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
to the output of the embedding layer, after summing the word
embeddings with the positional embeddings and the token type embeddings,
then performing layer normalization. This is the input to the transformer.
"""
return self.embedding_output
def get_embedding_table(self):
return self.embedding_table
def gelu(x):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def get_activation(activation_string):
"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
Args:
activation_string: String name of the activation function.
Returns:
A Python function corresponding to the activation function. If
`activation_string` is None, empty, or "linear", this will return None.
If `activation_string` is not a string, it will return `activation_string`.
Raises:
ValueError: The `activation_string` does not correspond to a known
activation.
"""
# We assume that anything that"s not a string is already an activation
# function, so we just return it.
if not isinstance(activation_string, six.string_types):
return activation_string
if not activation_string:
return None
act = activation_string.lower()
if act == "linear":
return None
elif act == "relu":
return tf.nn.relu
elif act == "gelu":
return gelu
elif act == "tanh":
return tf.tanh
else:
raise ValueError("Unsupported activation: %s" % act)
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
"""Compute the union of the current variables and checkpoint variables."""
assignment_map = {}
initialized_variable_names = {}
name_to_variable = collections.OrderedDict()
for var in tvars:
name = var.name
m = re.match("^(.*):\\d+$", name)
if m is not None:
name = m.group(1)
name_to_variable[name] = var
init_vars = tf.train.list_variables(init_checkpoint)
assignment_map = collections.OrderedDict()
for x in init_vars:
(name, var) = (x[0], x[1])
if name not in name_to_variable:
continue
assignment_map[name] = name
initialized_variable_names[name] = 1
initialized_variable_names[name + ":0"] = 1
return (assignment_map, initialized_variable_names)
def dropout(input_tensor, dropout_prob):
"""Perform dropout.
Args:
input_tensor: float Tensor.
dropout_prob: Python float. The probability of dropping out a value (NOT of
*keeping* a dimension as in `tf.nn.dropout`).
Returns:
A version of `input_tensor` with dropout applied.
"""
if dropout_prob is None or dropout_prob == 0.0:
return input_tensor
output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
return output
def layer_norm(input_tensor, name=None):
"""Run layer normalization on the last dimension of the tensor."""
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
"""Runs layer normalization followed by dropout."""
output_tensor = layer_norm(input_tensor, name)
output_tensor = dropout(output_tensor, dropout_prob)
return output_tensor
def create_initializer(initializer_range=0.02):
"""Creates a `truncated_normal_initializer` with the given range."""
return tf.truncated_normal_initializer(stddev=initializer_range, dtype=tf.flags.FLAGS.floatx)
def embedding_lookup(input_ids,
vocab_size,
embedding_size=128,
initializer_range=0.02,
word_embedding_name="word_embeddings",
use_one_hot_embeddings=False):
"""Looks up words embeddings for id tensor.
Args:
input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
ids.
vocab_size: int. Size of the embedding vocabulary.
embedding_size: int. Width of the word embeddings.
initializer_range: float. Embedding initialization range.
word_embedding_name: string. Name of the embedding table.
use_one_hot_embeddings: bool. If True, use one-hot method for word
embeddings. If False, use `tf.gather()`.
Returns:
float Tensor of shape [batch_size, seq_length, embedding_size].
"""
# This function assumes that the input is of shape [batch_size, seq_length,
# num_inputs].
#
# If the input is a 2D tensor of shape [batch_size, seq_length], we
# reshape to [batch_size, seq_length, 1].
if input_ids.shape.ndims == 2:
input_ids = tf.expand_dims(input_ids, axis=[-1])
embedding_table = tf.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
dtype=tf.flags.FLAGS.floatx,
initializer=create_initializer(initializer_range))
flat_input_ids = tf.reshape(input_ids, [-1])
if use_one_hot_embeddings:
one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
output = tf.matmul(one_hot_input_ids, embedding_table)
else:
output = tf.gather(embedding_table, flat_input_ids)
input_shape = get_shape_list(input_ids)
output = tf.reshape(output,
input_shape[0:-1] + [input_shape[-1] * embedding_size])
return (output, embedding_table)
def embedding_postprocessor(input_tensor,
use_token_type=False,
token_type_ids=None,
token_type_vocab_size=16,
token_type_embedding_name="token_type_embeddings",
use_position_embeddings=True,
position_embedding_name="position_embeddings",
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
"""Performs various post-processing on a word embedding tensor.
Args:
input_tensor: float Tensor of shape [batch_size, seq_length,
embedding_size].
use_token_type: bool. Whether to add embeddings for `token_type_ids`.
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
Must be specified if `use_token_type` is True.
token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
token_type_embedding_name: string. The name of the embedding table variable
for token type ids.
use_position_embeddings: bool. Whether to add position embeddings for the
position of each token in the sequence.
position_embedding_name: string. The name of the embedding table variable
for positional embeddings.
initializer_range: float. Range of the weight initialization.
max_position_embeddings: int. Maximum sequence length that might ever be
used with this model. This can be longer than the sequence length of
input_tensor, but cannot be shorter.
dropout_prob: float. Dropout probability applied to the final output tensor.
Returns:
float tensor with same shape as `input_tensor`.
Raises:
ValueError: One of the tensor shapes or input values is invalid.
"""
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
width = input_shape[2]
output = input_tensor
if use_token_type:
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
token_type_table = tf.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
dtype=tf.flags.FLAGS.floatx,
initializer=create_initializer(initializer_range))
# This vocab will be small so we always do one-hot here, since it is always
# faster for a small vocabulary.
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size, dtype=tf.flags.FLAGS.floatx)
token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
output += token_type_embeddings
if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
dtype=tf.flags.FLAGS.floatx,
initializer=create_initializer(initializer_range))
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
# tasks that do not have long sequences.
#
# So `full_position_embeddings` is effectively an embedding table
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
position_embeddings = tf.slice(full_position_embeddings, [0, 0],
[seq_length, -1])
num_dims = len(output.shape.as_list())
# Only the last two dimensions are relevant (`seq_length` and `width`), so
# we broadcast among the first dimensions, which is typically just
# the batch size.
position_broadcast_shape = []
for _ in range(num_dims - 2):
position_broadcast_shape.append(1)
position_broadcast_shape.extend([seq_length, width])
position_embeddings = tf.reshape(position_embeddings,
position_broadcast_shape)
output += position_embeddings
output = layer_norm_and_dropout(output, dropout_prob)
return output
def create_attention_mask_from_input_mask(from_tensor, to_mask):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = get_shape_list(to_mask, expected_rank=2)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.flags.FLAGS.floatx)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=tf.flags.FLAGS.floatx)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
def attention_layer(from_tensor,
to_tensor,
attention_mask=None,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
do_return_2d_tensor=False,
batch_size=None,
from_seq_length=None,
to_seq_length=None):
"""Performs multi-headed attention from `from_tensor` to `to_tensor`.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `from_tensor` and `to_tensor` are the same, then
this is self-attention. Each timestep in `from_tensor` attends to the
corresponding sequence in `to_tensor`, and returns a fixed-with vector.
This function first projects `from_tensor` into a "query" tensor and
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
of tensors of length `num_attention_heads`, where each tensor is of shape
[batch_size, seq_length, size_per_head].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor and returned.
In practice, the multi-headed attention are done with transposes and
reshapes rather than actual separate tensors.
Args:
from_tensor: float Tensor of shape [batch_size, from_seq_length,
from_width].
to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
attention_mask: (optional) int32 Tensor of shape [batch_size,
from_seq_length, to_seq_length]. The values should be 1 or 0. The
attention scores will effectively be set to -infinity for any positions in
the mask that are 0, and will be unchanged for positions that are 1.
num_attention_heads: int. Number of attention heads.
size_per_head: int. Size of each attention head.
query_act: (optional) Activation function for the query transform.
key_act: (optional) Activation function for the key transform.
value_act: (optional) Activation function for the value transform.
attention_probs_dropout_prob: (optional) float. Dropout probability of the
attention probabilities.
initializer_range: float. Range of the weight initializer.
do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
* from_seq_length, num_attention_heads * size_per_head]. If False, the
output will be of shape [batch_size, from_seq_length, num_attention_heads
* size_per_head].
batch_size: (Optional) int. If the input is 2D, this might be the batch size
of the 3D version of the `from_tensor` and `to_tensor`.
from_seq_length: (Optional) If the input is 2D, this might be the seq length
of the 3D version of the `from_tensor`.
to_seq_length: (Optional) If the input is 2D, this might be the seq length
of the 3D version of the `to_tensor`.
Returns:
float Tensor of shape [batch_size, from_seq_length,
num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
true, this will be of shape [batch_size * from_seq_length,
num_attention_heads * size_per_head]).
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
seq_length, width):
output_tensor = tf.reshape(
input_tensor, [batch_size, seq_length, num_attention_heads, width])
output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
return output_tensor
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
if len(from_shape) != len(to_shape):
raise ValueError(
"The rank of `from_tensor` must match the rank of `to_tensor`.")
if len(from_shape) == 3:
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_seq_length = to_shape[1]
elif len(from_shape) == 2:
if (batch_size is None or from_seq_length is None or to_seq_length is None):
raise ValueError(
"When passing in rank 2 tensors to attention_layer, the values "
"for `batch_size`, `from_seq_length`, and `to_seq_length` "
"must all be specified.")
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
from_tensor_2d = reshape_to_matrix(from_tensor)
to_tensor_2d = reshape_to_matrix(to_tensor)
# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
kernel_initializer=create_initializer(initializer_range))
# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
kernel_initializer=create_initializer(initializer_range))
# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
kernel_initializer=create_initializer(initializer_range))
# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
num_attention_heads, from_seq_length,
size_per_head)
# `key_layer` = [B, N, T, H]
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
to_seq_length, size_per_head)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
# `attention_scores` = [B, N, F, T]
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(size_per_head))
if attention_mask is not None:
# `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1])
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(attention_mask, tf.flags.FLAGS.floatx)) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_scores += adder
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = tf.nn.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
# `value_layer` = [B, T, N, H]
value_layer = tf.reshape(
value_layer,
[batch_size, to_seq_length, num_attention_heads, size_per_head])
# `value_layer` = [B, N, T, H]
value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
# `context_layer` = [B, N, F, H]
context_layer = tf.matmul(attention_probs, value_layer)
# `context_layer` = [B, F, N, H]
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
if do_return_2d_tensor:
# `context_layer` = [B*F, N*H]
context_layer = tf.reshape(
context_layer,
[batch_size * from_seq_length, num_attention_heads * size_per_head])
else:
# `context_layer` = [B, F, N*H]
context_layer = tf.reshape(
context_layer,
[batch_size, from_seq_length, num_attention_heads * size_per_head])
return context_layer
def transformer_model(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
do_return_all_layers=False):
"""Multi-headed, multi-layer Transformer from "Attention is All You Need".
This is almost an exact implementation of the original Transformer encoder.
See the original paper:
https://arxiv.org/abs/1706.03762
Also see:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
Args:
input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
seq_length], with 1 for positions that can be attended to and 0 in
positions that should not be.
hidden_size: int. Hidden size of the Transformer.
num_hidden_layers: int. Number of layers (blocks) in the Transformer.
num_attention_heads: int. Number of attention heads in the Transformer.
intermediate_size: int. The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: function. The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: float. Dropout probability for the hidden layers.
attention_probs_dropout_prob: float. Dropout probability of the attention
probabilities.
initializer_range: float. Range of the initializer (stddev of truncated
normal).
do_return_all_layers: Whether to also return all layers or just the final
layer.
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size], the final
hidden layer of the Transformer.
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
# The Transformer performs sum residuals on all layers so the input needs
# to be the same as the hidden size.
if input_width != hidden_size:
raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
(input_width, hidden_size))
# We keep the representation as a 2D tensor to avoid re-shaping it back and
# forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
# the GPU/CPU but may not be free on the TPU, so we want to minimize them to
# help the optimizer.
prev_output = reshape_to_matrix(input_tensor)
all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
layer_input = prev_output
with tf.variable_scope("attention"):
attention_heads = []
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_heads.append(attention_head)
attention_output = None
if len(attention_heads) == 1:
attention_output = attention_heads[0]
else:
# In the case where we have other sequences, we just concatenate
# them to the self-attention head before the projection.
attention_output = tf.concat(attention_heads, axis=-1)
# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
prev_output = layer_output
all_layer_outputs.append(layer_output)
if do_return_all_layers:
final_outputs = []
for layer_output in all_layer_outputs:
final_output = reshape_from_matrix(layer_output, input_shape)
final_outputs.append(final_output)
return final_outputs
else:
final_output = reshape_from_matrix(prev_output, input_shape)
return final_output
def get_shape_list(tensor, expected_rank=None, name=None):
"""Returns a list of the shape of tensor, preferring static dimensions.
Args:
tensor: A tf.Tensor object to find the shape of.
expected_rank: (optional) int. The expected rank of `tensor`. If this is
specified and the `tensor` has a different rank, and exception will be
thrown.
name: Optional name of the tensor for the error message.
Returns:
A list of dimensions of the shape of tensor. All static dimensions will
be returned as python integers, and dynamic dimensions will be returned
as tf.Tensor scalars.
"""
if name is None:
name = tensor.name
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def reshape_to_matrix(input_tensor):
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
ndims = input_tensor.shape.ndims
if ndims < 2:
raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
(input_tensor.shape))
if ndims == 2:
return input_tensor
width = input_tensor.shape[-1]
output_tensor = tf.reshape(input_tensor, [-1, width])
return output_tensor
def reshape_from_matrix(output_tensor, orig_shape_list):
"""Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
if len(orig_shape_list) == 2:
return output_tensor
output_shape = get_shape_list(output_tensor)
orig_dims = orig_shape_list[0:-1]
width = output_shape[-1]
return tf.reshape(output_tensor, orig_dims + [width])
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
if name is None:
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
scope_name = tf.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))

View file

@ -0,0 +1,196 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
# export GLUE_DIR=/path/to/glue
# python profile_bert_inference.py --task_name=MRPC --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --predict_batch_size=8 --max_seq_length=128 --output_dir=mrpc_output --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt --tf_profile=true --profiling_output_file=time_elapsed --xla=false --floatx=float32
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import numpy as np
import fast_infer_util as fiu
import profile_util
import tensorflow as tf
import os
from tensorflow.python.client import timeline
import contextlib
import time
from tensorflow.python.client import device_lib
import my_modeling
bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert')
sys.path.insert(0, bert_submodule)
import tokenization
import run_classifier as rc
flags = tf.flags
FLAGS = flags.FLAGS
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids):
"""Creates a classification model."""
model = my_modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=False)
seq_output = model.get_sequence_output()
return seq_output
def model_fn_builder(bert_config):
def model_fn(features):
# print features
tf.logging.info("*** Features ***")
for name in sorted(features.keys()):
tf.logging.info(" name = %s, shape = %s" %
(name, features[name].shape))
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
label_ids = features["label_ids"]
fetches = create_model(
bert_config, False, input_ids, input_mask, segment_ids)
# # fetch mrpc logits for prediction
# num_labels = 2 # for mrpc
# _, _, fetches, _ = fiu.create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,
# num_labels, False)
return fetches
return model_fn
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
num_iter = 20
jit_xla = tf.OptimizerOptions.ON_1 if FLAGS.xla else 0
processors = {
"cola": rc.ColaProcessor,
"mnli": rc.MnliProcessor,
"mrpc": rc.MrpcProcessor,
"xnli": rc.XnliProcessor,
}
# sanity check
tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
FLAGS.init_checkpoint)
bert_config = my_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(FLAGS.max_seq_length, bert_config.max_position_embeddings))
tf.gfile.MakeDirs(FLAGS.output_dir)
task_name = FLAGS.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
# prepare data
processor = processors[task_name]()
label_list = processor.get_labels()
predict_examples = processor.get_test_examples(FLAGS.data_dir)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
rc.file_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file)
# get model function and input function
# drop_remainder option should be turned on for fast transformer inference
drop_remainder = True
predict_input_fn = rc.file_based_input_fn_builder(
input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=drop_remainder)
def graph_fn():
model_fn = model_fn_builder(bert_config=bert_config)
dataset = predict_input_fn({'batch_size': FLAGS.predict_batch_size})
next_item = dataset.make_one_shot_iterator().get_next()
output_var = model_fn(next_item)
return output_var
if FLAGS.tf_profile:
tf.logging.info("***** Running tensorflow transformer*****")
p1 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_origin'))
t1, r1 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, p1, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
my_modeling.transformer_model = fiu.fast_transformer_model_trans
tf.logging.info("***** Running fast transformer*****")
p2 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_fastinfer'))
t2, r2 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, p2, init_checkpoint=FLAGS.init_checkpoint)
else:
tf.logging.info("***** Running tensorflow transformer*****")
t1, r1 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
my_modeling.transformer_model = fiu.fast_transformer_model_trans
tf.logging.info("***** Running fast transformer*****")
t2, r2 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
print('average time (seconds) elasped original tensorflow:', t1)
print('average time (seconds) elasped fast transformer:', t2)
if len(r1) + len(r2) > 0:
check_res = np.asarray([np.allclose(
r1[i], r2[i], atol=1e-4, rtol=0) for i in range(num_iter)])
if check_res.all():
print('Pass')
print(np.mean(r1))
print(np.mean(r2))
else:
for i in np.where(np.logical_not(check_res))[0]:
diff = np.fabs(r1[i] - r2[i])
idx = np.unravel_index(diff.argmax(), diff.shape)
print('Failed iter:', i, "max diff:",
diff[idx], idx, r1[i][idx], r2[i][idx])
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
flags.DEFINE_string("profiling_output_file", None,
"The output file for profiling results.")
flags.mark_flag_as_required("profiling_output_file")
flags.DEFINE_string("floatx", "float32", "float32 or float16")
flags.mark_flag_as_required("floatx")
flags.DEFINE_bool("xla", False, "whether to turn on XLA")
flags.mark_flag_as_required("xla")
flags.DEFINE_bool("tf_profile", False,
"whether to use tensorflow profiling")
tf.app.run()

View file

@ -0,0 +1,223 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
# python profile_transformer_inference.py --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt --tf_profile=false --output_dir=mrpc_output --profiling_output_file=time_elapsed --xla=false --floatx=float32
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import device_lib
import time
import contextlib
from tensorflow.python.client import timeline
import os
import tensorflow as tf
import fast_infer_util as fiu
import numpy as np
import profile_util
import sys
import my_modeling
bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert')
sys.path.insert(0, bert_submodule)
import run_classifier
import optimization
flags = tf.flags
FLAGS = flags.FLAGS
# stacked transformer encoders
class TransformerModel(object):
def __init__(self,
config,
is_training,
input_tensor,
attention_mask,
transformer_model_fn,
scope=None):
config = my_modeling.copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
input_shape = my_modeling.get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope("encoder"):
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model_fn(
input_tensor=input_tensor,
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_act_fn=my_modeling.get_activation(
config.hidden_act),
hidden_dropout_prob=config.hidden_dropout_prob,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
initializer_range=config.initializer_range,
do_return_all_layers=True)
self.sequence_output = self.all_encoder_layers[-1]
with tf.variable_scope("pooler"):
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
kernel_initializer=my_modeling.create_initializer(config.initializer_range))
def get_pooled_output(self):
return self.pooled_output
def get_sequence_output(self):
return self.sequence_output
def model_fn_builder(bert_config, transformer_model_fn):
def model_fn(input_tensor, attention_mask): # pylint: disable=unused-argument
model = TransformerModel(
config=bert_config,
is_training=False,
input_tensor=input_tensor,
attention_mask=attention_mask,
transformer_model_fn=transformer_model_fn)
seq_output = model.get_sequence_output()
return seq_output
return model_fn
def profile_model(config, jit_xla, num_iter):
# initialize data
input_data = np.random.randn(
FLAGS.predict_batch_size, FLAGS.max_seq_length, config.hidden_size)
attention_mask = np.random.randint(2, size=(
FLAGS.predict_batch_size, FLAGS.max_seq_length))
attention_mask = np.repeat(
attention_mask[:, np.newaxis, :], FLAGS.max_seq_length, axis=1)
model_fn_tf = model_fn_builder(config, my_modeling.transformer_model)
model_fn_ft = model_fn_builder(config, fiu.fast_transformer_model_trans)
def graph_fn_builder(model_fn):
def graph_fn():
input_tensor = tf.constant(input_data, dtype=FLAGS.floatx)
mask_tensor = tf.constant(attention_mask, dtype=FLAGS.floatx)
output_var = model_fn(input_tensor, mask_tensor)
# for saving memcopy time
return tf.reduce_mean(output_var)
return graph_fn
if FLAGS.tf_profile:
tf.logging.info("***** Running tensorflow transformer*****")
p1 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_origin'))
t1, r1 = profile_util.run_profile(graph_fn_builder(
model_fn_tf), jit_xla, num_iter, p1, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
tf.logging.info("***** Running fast transformer*****")
p2 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_fastinfer'))
t2, r2 = profile_util.run_profile(graph_fn_builder(
model_fn_ft), jit_xla, num_iter, p2, init_checkpoint=FLAGS.init_checkpoint)
else:
tf.logging.info("***** Running tensorflow transformer*****")
t1, r1 = profile_util.run_profile(graph_fn_builder(
model_fn_tf), jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
tf.logging.info("***** Running fast transformer*****")
t2, r2 = profile_util.run_profile(graph_fn_builder(
model_fn_ft), jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
# check errors
print('average time (seconds) elasped original tensorflow:', t1)
print('average time (seconds) elasped fast transformer:', t2)
if len(r1) + len(r2) > 0:
check_res = np.asarray([np.allclose(
r1[i], r2[i], atol=1e-4, rtol=0) for i in range(num_iter)])
if check_res.all():
print('Pass')
print(np.mean(r1))
print(np.mean(r2))
else:
for i in np.where(np.logical_not(check_res))[0]:
diff = np.fabs(r1[i] - r2[i])
idx = np.unravel_index(diff.argmax(), diff.shape)
print('Failed iter:', i, "max diff:",
diff[idx], idx, r1[i][idx], r2[i][idx])
return t1, t2
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
batch_size = [8]
seq_length = [128]
num_hidden_layers = [12]
attention_heads_num_size = [(12, 64)]
num_iter = 20
interval = 0
# collect results of both original bert and fast transformer
jit_xla = tf.OptimizerOptions.ON_1 if FLAGS.xla else 0
config = my_modeling.BertConfig(vocab_size=0)
tf.gfile.MakeDirs(FLAGS.output_dir)
local_device_protos = device_lib.list_local_devices()
with open(os.path.join(FLAGS.output_dir, FLAGS.profiling_output_file), 'w') as f:
for x in local_device_protos:
if x.device_type == 'GPU':
f.write(x.physical_device_desc + '\n')
f.write(str(FLAGS.floatx) + '\t' + 'XLA: ' + str(FLAGS.xla) + '\n')
f.write('batch_size\tseq_length\thidden_layers\tattention_heads\tattention_head_size\tTensorflow\tFasterTransformer\n')
for bs in batch_size:
FLAGS.predict_batch_size = bs
for sl in seq_length:
FLAGS.max_seq_length = sl
for hidden_layers in num_hidden_layers:
config.num_hidden_layers = hidden_layers
for head_num, head_size in attention_heads_num_size:
config.num_attention_heads = head_num
config.hidden_size = head_num * head_size
time.sleep(interval)
t1, t2 = profile_model(config, jit_xla, num_iter)
tmp = [FLAGS.predict_batch_size, FLAGS.max_seq_length, hidden_layers, head_num, head_size,
'{:.6}'.format(t1), '{:.6}'.format(t2)]
f.write('\t'.join([str(x) for x in tmp]) + '\n')
if __name__ == "__main__":
flags.mark_flag_as_required("output_dir")
flags.DEFINE_string("profiling_output_file", None,
"The output file for profiling results.")
flags.mark_flag_as_required("profiling_output_file")
flags.DEFINE_string("floatx", "float32", "float32 or float16")
flags.mark_flag_as_required("floatx")
flags.DEFINE_bool("xla", False, "whether to turn on XLA")
flags.mark_flag_as_required("xla")
flags.DEFINE_bool("tf_profile", False,
"whether to use tensorflow profiling")
tf.app.run()

View file

@ -0,0 +1,88 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow.python.client import device_lib
import time
import contextlib
from tensorflow.python.client import timeline
import os
import tensorflow as tf
class Profiler():
def __init__(self, profile_name_pref):
self.profile_name_pref = profile_name_pref
self.run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
self.run_metadata = tf.RunMetadata()
self.ctr = 0
self.time_avg = 0
@contextlib.contextmanager
def prof_run(self):
start = time.time()
yield
end = time.time()
self.time_avg = (self.time_avg * self.ctr + end - start)/(self.ctr + 1)
fetched_timeline = timeline.Timeline(self.run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
file_name = self.profile_name_pref + '_' + str(self.ctr) + '.json'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(chrome_trace)
self.ctr += 1
def run_profile(graph_fn, jit_xla, num_iter, profiler=None, init_checkpoint=None, check_result=True, dryrun_iter=1):
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = jit_xla
fetches = graph_fn()
with tf.Session(config=config) as sess:
# init
if init_checkpoint is None:
sess.run(tf.global_variables_initializer())
else:
saver = tf.train.Saver()
saver.restore(sess, init_checkpoint)
# dry run
for _ in range(dryrun_iter):
sess.run(fetches)
res = []
if profiler is None:
start_time = time.time()
if check_result:
for _ in range(num_iter):
res.append(sess.run(fetches))
else:
for _ in range(num_iter):
sess.run(fetches)
end_time = time.time()
time_avg = (end_time - start_time)/num_iter
else:
if check_result:
for _ in range(num_iter):
with profiler.prof_run():
res.append(sess.run(fetches, options=profiler.run_options, run_metadata=profiler.run_metadata))
else:
for _ in range(num_iter):
with profiler.prof_run():
sess.run(fetches, options=profiler.run_options, run_metadata=profiler.run_metadata)
time_avg = profiler.time_avg
return time_avg, res

View file

@ -0,0 +1,76 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
# export GLUE_DIR=/path/to/glue
# python run_classifier_wrap.py --floatx=float16 --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=mrpc_output/fp16_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output
# FP32 Tensorflow Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.877451
# INFO:tensorflow: eval_loss = 0.44744828
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.44744828
# FP32 Faster Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.877451
# INFO:tensorflow: eval_loss = 0.4474482
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.4474482
# FP16 Tensorflow Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.875
# INFO:tensorflow: eval_loss = 0.44760832
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.44760215
# FP16 Faster Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.875
# INFO:tensorflow: eval_loss = 0.44731623
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.44728807
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert')
sys.path.insert(0, bert_submodule)
import tensorflow as tf
import run_classifier as rc
import fast_infer_util as fiu
import my_modeling
flags = tf.flags
FLAGS = flags.FLAGS
# replace transformer implementation
my_modeling.transformer_model = fiu.fast_transformer_model_trans
# replace the model to support fp16 data type
rc.create_model = fiu.create_model
# replace the input function to drop remainder
rc.file_based_input_fn_builder = fiu.file_based_input_fn_builder_drop
main = rc.main
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
flags.DEFINE_string("floatx", None, "float32 or float16")
flags.mark_flag_as_required("floatx")
tf.app.run()

View file

@ -0,0 +1,148 @@
Tensorflow BERT Samples
---
**Using bert_transformer Tensorflow op in a transformer encoder**
The trunk network of BERT model consists of a multi-layer transformer encoder,
which is implemented as the `transformer_model()` function in the file `modeling.py` in their official [Github repository](https://github.com/google-research/bert).
Samples provided in file `fast_infer_util.py` show how to re-implement this function with our ops in order to get an inference time speedup.
The function `fast_transformer_model_trans()` implements the transformer encoder using the `bert_transformer` op.
In order to do that, we only need to first import the op at the beginning of the file, then call `bert_transformer` op at the end of each encoder layer. This turns out can be done by adding several lines of code to the original `transformer_model()` function as the following.
```python
# import op
transformer_op_module = tf.load_op_library(os.path.join('../../build/lib/libtf_fastertransformer.so'))
...
def fast_transformer_model_trans(...)
...
# original code
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
# calling bert_transformer
trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name)
layer_output = transformer_op_module.bert_transformer(
layer_input,
layer_input,
trainable_vars[0], trainable_vars[2], trainable_vars[4], trainable_vars[1], trainable_vars[3], trainable_vars[5],
attention_mask,
trainable_vars[6], trainable_vars[7], trainable_vars[8], trainable_vars[9], trainable_vars[10], trainable_vars[11],
trainable_vars[12], trainable_vars[13], trainable_vars[14], trainable_vars[15],
batch_size=batch_size, from_seq_len=seq_length, to_seq_len=seq_length, head_num=num_attention_heads, size_per_head=attention_head_size)
# original code
prev_output = layer_output
all_layer_outputs.append(layer_output)
...
```
**Running GLEU tasks with fast transformer inference**
The above shows how to implement a transformer encoder using our ops, to integrate it into the BERT pipeline
we can simply replace the `transformer_model` function in `modeling.py` with `fast_transformer_model_trans`.
Our implementation supports FP16 data type to further exploit the potential of inference acceleration.
FP16 inference was not supported by the original BERT code, here we made necessary modifications to build a FP16 compatible model,
which was implemented in `my_modeling.py` and the `create_model` function in `fast_infer_util.py`.
FP32 Tensorflow checkpoint files cannot be used directly for FP16 inference, we can convert its data type to FP16 in advance.
The `ckpt_type_convert.py` script is provided for checkpoint data type conversion.
It is also important to note that our implementation requires a fixed batch size, this can be done by setting `drop_remainder` option to `True` for Tensorflow `Dataset` instances. We have re-implemented this as well in the `file_based_input_fn_builder_drop` function.
On top of the above modifications, it's easy to run any of the GLEU tasks supported by the open source BERT sample with our ops for better inference performance. We only need to replace several functions in original `run_classifier.py` script with the implementations we provide.
```python
import run_classifier as rc
import fast_infer_util as fiu
import my_modeling
...
# replace transformer implementation
my_modeling.transformer_model = fiu.fast_transformer_model_trans
# replace the model to support fp16 data type
rc.create_model = fiu.create_model
# replace the input function to drop remainder
rc.file_based_input_fn_builder = fiu.file_based_input_fn_builder_drop
...
```
The sample `run_classifier_wrap.py` is a wrapper of the original `run_classifier.py` script for BERT, it supports the same options as described in [BERT readme](https://github.com/google-research/bert) with additional `floatx` options to specify floating point type.
For example, to compare the performance of original BERT and our implementation on MRPC task we can run the following command.
```bash
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
python run_classifier.py --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=ckpt_dir/fp32_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output
python run_classifier_wrap.py --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=ckpt_dir/fp16_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output --floatx=float16
```
The evaluation result should be like
```
# original Tensorflow op
...
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: eval_accuracy = 0.877451
INFO:tensorflow: eval_loss = 0.44744828
INFO:tensorflow: global_step = 0
INFO:tensorflow: loss = 0.44744828
# faster_transformer op with fp16 data type
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: eval_accuracy = 0.875
INFO:tensorflow: eval_loss = 0.44731623
INFO:tensorflow: global_step = 0
INFO:tensorflow: loss = 0.44728807
...
```
We see the evaluation accuracy and loss drop slightly with FP16 inference for the MRPC sentence pair classification task.
The following section will show such minor sacrifice in accuracy will bring considerable performance gain.
**Tensorflow profiling**
The sample script `profile_transformer_inference.py` shows how to run and profile a BERT inference model from scratch. Results show we received a 6.36x speedup compared to FP32 Tensorflow (1.48x speedup compared to FP16 Tensorflow XLA) for an end-to-end classification model in our experiment settings.
GPU: Tesla T4
CUDA: 10.0.0
Model: BERT-Base: 12-layer, 768-hidden, 12-heads , 110M parameters
Max sequence length: 128
Batch size: 32
Average time elapsed:
| settings | seconds |
| ------------- | ------------- |
| FP32 Tensorflow | 0.2495 |
| FP32 Tensorflow XLA | 0.1998 |
| FP16 Tensorflow | 0.0978 |
| FP16 Tensorflow XLA | 0.0582 |
| FP16 FasterTransformer | 0.0392 |
**Content summary**
| file name | summary |
| ------------- | ------------- |
| `ckpt_type_convert.py` | script for checkpoint data type conversion |
| `fast_infer_util.py` | example functions to use faster transformer ops in Tensorflow |
| `my_modeling.py` | basically the same as `modeling.py` in the original BERT repository, modifications are made to support FP16 data types |
| `run_classifier_wrap.py` | a wrapper script of `run_classifier.py` in the original BERT repository, shows how to run classification tasks using faster transformer ops |
| `profile_bert_inference.py` | for profiling BERT model pipelines |
| `profile_transformer_inference.py` | for profiling transformer encoder layers |
| `profile_util.py` | helper functions for profiling |