Adding FasterTransformer: A faster transformer layer inference implementation for BERT and other transformer based models.

This commit is contained in:
Xipeng Li 2019-07-14 00:29:45 +08:00
parent f89dcca19d
commit 75502be814
44 changed files with 7239 additions and 0 deletions

View file

@ -0,0 +1,121 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(FasterTransformer LANGUAGES CXX CUDA)
find_package(CUDA 10.0 REQUIRED)
option(BUILD_TRT "Build in TensorRT mode" OFF)
option(BUILD_TF "Build in TensorFlow mode" OFF)
set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
set(TF_PATH "" CACHE STRING "TensorFlow path")
#set(TF_PATH "/usr/local/lib/python3.5/dist-packages/tensorflow")
if(BUILD_TF AND NOT TF_PATH)
message(FATAL_ERROR "TF_PATH must be set if BUILD_TF(=TensorFlow mode) is on.")
endif()
set(TRT_PATH "" CACHE STRING "TensorRT path")
#set(TRT_PATH "/myspace/TensorRT-5.1.5.0")
if(BUILD_TRT AND NOT TRT_PATH)
message(FATAL_ERROR "TRT_PATH must be set if BUILD_TRT(=TensorRT mode) is on.")
endif()
list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)
find_package(CUDA REQUIRED)
# setting compiler flags
if (SM STREQUAL 70 OR
SM STREQUAL 75 OR
SM STREQUAL 61 OR
SM STREQUAL 60)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\" -rdc=true")
if (SM STREQUAL 70 OR SM STREQUAL 75)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")
message("-- Assign GPU architecture (sm=${SM})")
else()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_60,code=\\\"sm_60,compute_60\\\" -rdc=true")
message("-- Unknown or unsupported GPU architecture (set sm=60)")
endif()
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
if(CMAKE_CXX_STANDARD STREQUAL "11")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++11")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -O3")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(COMMON_HEADER_DIRS
${PROJECT_SOURCE_DIR}
${CUDA_PATH}/include
)
set(COMMON_LIB_DIRS
${CUDA_PATH}/lib64
)
if(BUILD_TF)
list(APPEND COMMON_HEADER_DIRS ${TF_PATH}/include)
list(APPEND COMMON_LIB_DIRS ${TF_PATH})
endif()
if(BUILD_TRT)
list(APPEND COMMON_HEADER_DIRS ${TRT_PATH}/include)
list(APPEND COMMON_LIB_DIRS ${TRT_PATH}/lib)
endif()
include_directories(
${COMMON_HEADER_DIRS}
)
link_directories(
${COMMON_LIB_DIRS}
)
add_subdirectory(tools)
add_subdirectory(fastertransformer)
add_subdirectory(sample)
if(BUILD_TF)
add_custom_target(copy ALL COMMENT "Copying tensorflow test scripts")
add_custom_command(TARGET copy
POST_BUILD
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/*.py ${PROJECT_SOURCE_DIR}/build/
)
endif()

110
FasterTransformer/README.md Normal file
View file

@ -0,0 +1,110 @@
Faster Transformer
===================
## What is it?
The Faster Transformer implements an equivalent but highly optimized BERT transformer layer for inference. On Volta and Turing GPUs, FP16 precision is used automatically to access the computing power of tensor cores.
Faster Transformer is built on top of the CUDA and cuBLAS. It supports three kinds of sequence lengths, 32, 64 and 128. Two key parameters of the transformer layer, the number of heads and the size of each head, are passed in runtime. Thus, not only the BERT Base (12 heads * 64 per head) , but also customized models like 4 heads * 32 per head and 8 heads * 96 per heads, are well supported. Our implementation shows good speedups on both small and large batch size cases.
C++ API, TensorRT plugin, and TensorFlow OP wrapper are available. You can easily integrate this optimized transformer layer into your TensorFlow or other inference service codes that built in native C++ or TensorRT. In addition to codes that illustrate the API invocations, we also provide a simple end-to-end BERT TensorFlow inference sample.
## Environment requirements
* CMake >= 3.8
* CUDA 10.0
* Python 2.7
* Tensorflow 1.13
* TensorRT 5.1.5
* The project is tested in nvidia/cuda 10.0-cudnn7-devel-ubuntu16.04 docker image. If you encountered compiling errors, try to compile with this docker image.
## Performance ##
* CPU: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
* T4 (with mclk 5000MHz, pclk 1590MHz)
* P4 (with mclk 2999MHz, pclk 1531MHz)
* V100 (with mclk 877MHz, pclk 1380MHz)
When batch size equals to 1, the Tensorflow execution time really depends on the CPU you are using.
We only report the faster transformer performance here.
The performance of the faster transformer mainly depends on GPU. The execution time is stable.
| <batch_size, layers, seq_len, head_num, size_per_head> | P4 FP32 (in ms) | T4 FP32 (in ms)| T4 FP16 (in ms)|
|:-------------:|:-------------:|:---------:|:-----------:|
| (1, 12, 32, 12, 64) | 3.43 | 2.74 | 1.56 |
| (1, 12, 64, 12, 64) | 4.04 | 3.64 | 1.77 |
| (1, 12, 128, 12, 64) | 6.22 | 5.93 | 2.23 |
For large batch size case, we report both Tensorflow XLA and faster transformer's performance.
| <batch_size, layers, seq_len, head_num, size_per_head> | Tensorflow XLA on V100 FP16 (in ms)| Faster Transformer V100 FP16 (in ms) | Speedup |
|:-------------:|:-------------:|:---------:|:-----------:|
| (100, 12, 32, 12, 64) | 13.96 | 9.57 | 1.459 |
| (200, 12, 32, 12, 64) | 26.47 | 18.37 | 1.44 |
| (300, 12, 32, 12, 64) | 38.4 | 27.41 | 1.401 |
| (400, 12, 32, 12, 64) | 49.65 | 35.63 | 1.393 |
| (500, 12, 32, 12, 64) | 62.2 | 44.57 | 1.396 |
| <batch_size, layers, seq_len, head_num, size_per_head> | Tensorflow XLA on V100 FP16 (in ms)| Faster Transformer V100 FP16 (in ms) | Speedup |
|:-------------:|:-------------:|:---------:|:-----------:|
| (100, 12, 32, 4, 32) | 3.49 | 1.73 | 2.017 |
| (200, 12, 32, 4, 32) | 4.9 | 2.55 | 1.922 |
| (300, 12, 32, 4, 32) | 6.35 | 3.356 | 1.892 |
| (400, 12, 32, 4, 32) | 8 | 4.31 | 1.856 |
| (500, 12, 32, 4, 32) | 9.93 | 5.13 | 1.936 |
## Directory Structure
```
/fastertransformer: source code of transformer
|--/cuda: some CUDA kernels and multi-head attention implementation, both are compiled with cuda/cuBLAS.
|--/tf_op: custom Tensorflow OP implementation
|--/trt_plugin: TensorRT plugin implementation
/sample: c++ and tensorflow transformer interface samples
|--/cpp: both FP16 and FP32 c++ interface samples
|--/fastertransformer_bert: samples that show of how to integrate our Tensorflow OP into the open source BERT model for sentence (and sentence-pair) classification tasks (GLUE), the samples support both FP16 and FP32, see readme file within this folder more details
|--/tensorflow: both FP16 and FP32 tensorflow OP samples
|--/tensorRT: both FP16 and FP32 tensorRT plugin samples
/tools/gemm_test: loop over all GEMM algorithms to pick the best one
```
## How to build?
### Init Git ###
```shell
$ git submodule init
$ git submodule update
```
### Build with Release ###
```shell
$ mkdir -p build
$ cd build
$ cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release .. # C++ only
$ cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DTRT_PATH=/myspace/TensorRT-5.1.5.0 .. # TensorRT mode
$ cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python2.7/dist-packages/tensorflow .. # Tensorflow mode
$ cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DTRT_PATH=/myspace/TensorRT-5.1.5.0 -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python2.7/dist-packages/tensorflow .. # C++, TensorRT and Tensorflow
$ make
```
Note: xx is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4).
### Execute demos ###
```shell
$ Step1 Generate the gemm_config.in file under the path build to pick GEMM algorithms for the best performance.
$ ./build/bin/gemm_fp16(32) <batch_size> <seq_len> <head_num> <size_per_head>
$ Step2 Execute demos
$ 1. Tensorflow demos: python build/transformer_fp16(32).py <batch_size> <num_layers> <seq_len> <head_num> <size_per_head>
$ 2. c++ demos: ./build/bin/transformer_fp16(32) <batch_size> <num_layerse> <seq_len> <head_num> <size_per_head>
$ 3. TensorRT demos: ./build/bin/transformer_trt <batch_size> <num_layerse> <seq_len> <head_num> <size_per_head> fp16(fp32)
```
### Useful sample code ###
```shell
$ 1. sample/tensorflow/transformer_fp32.py: transformer_layer Tensorflow FP32 OP call, time measurement, timeline generation
$ 2. sample/tensorflow/transformer_fp16.py: transformer_layer Tensorflow FP16 OP call, time measurement, timeline generation
$ 3. sample/tensorflow/error_check.py: how to catch custom OP runtime errors
$ 4. sample/cpp/transformer_fp32.cc: transformer layer C++ FP32 sample
$ 5. sample/cpp/transformer_fp16.cc: transformer layer C++ FP16 sample
$ 6. sample/tensorRT/transformer_trt.cc: transformer layer tensorRT FP32/FP16 sample
$ 7. tools/gemm_test/gemm_fp16.cu: loop over all cublas FP16 GEMM algorithms and pick the best one
$ 8. tools/gemm_test/gemm_fp32.cu: loop over all cublas FP32 GEMM algorithms and pick the best one
```

View file

@ -0,0 +1,180 @@
# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
# file Copyright.txt or https://cmake.org/licensing for details.
#.rst:
# FindCUDNN
# -------
#
# Find CUDNN library
#
# Valiables that affect result:
# <VERSION>, <REQUIRED>, <QUIETLY>: as usual
#
# <EXACT> : as usual, plus we do find '5.1' version if you wanted '5'
# (not if you wanted '5.0', as usual)
#
# Result variables
# ^^^^^^^^^^^^^^^^
#
# This module will set the following variables in your project:
#
# ``CUDNN_INCLUDE``
# where to find cudnn.h.
# ``CUDNN_LIBRARY``
# the libraries to link against to use CUDNN.
# ``CUDNN_FOUND``
# If false, do not try to use CUDNN.
# ``CUDNN_VERSION``
# Version of the CUDNN library we looked for
#
# Exported functions
# ^^^^^^^^^^^^^^^^
# function(CUDNN_INSTALL version __dest_libdir [__dest_incdir])
# This function will try to download and install CUDNN.
# CUDNN5 and CUDNN6 are supported.
#
#
function(CUDNN_INSTALL version dest_libdir dest_incdir dest_bindir)
message(STATUS "CUDNN_INSTALL: Installing CUDNN ${version}, lib:${dest_libdir}, inc:${dest_incdir}, bin:${dest_bindir}")
string(REGEX REPLACE "-rc$" "" version_base "${version}")
set(tar_libdir cuda/lib64)
set(tar_incdir cuda/include)
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
set(url_extension tgz)
if("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64")
set(url_arch_name linux-x64 )
elseif("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "ppc")
set(url_arch_name linux-ppc64le )
# TX1 has to be installed via JetPack
endif()
elseif (APPLE)
set(url_extension tgz)
set(tar_libdir cuda/lib)
set(url_arch_name osx-x64)
elseif(WIN32)
set(url_extension zip)
set(tar_bindir cuda/bin)
set(tar_libdir cuda/lib/x64)
if(CMAKE_SYSTEM_VERSION MATCHES "10")
set(url_arch_name windows10-x64)
else()
set(url_arch_name windows7-x64)
endif()
endif()
# Download and install CUDNN locally if not found on the system
if(url_arch_name)
set(download_dir ${CMAKE_CURRENT_BINARY_DIR}/downloads/cudnn${version})
file(MAKE_DIRECTORY ${download_dir})
set(cudnn_filename cudnn-${CUDA_VERSION}-${url_arch_name}-v${version}.${url_extension})
set(base_url http://developer.download.nvidia.com/compute/redist/cudnn)
set(cudnn_url ${base_url}/v${version_base}/${cudnn_filename})
set(cudnn_file ${download_dir}/${cudnn_filename})
if(NOT EXISTS ${cudnn_file})
message(STATUS "Downloading CUDNN library from NVIDIA...")
file(DOWNLOAD ${cudnn_url} ${cudnn_file}
SHOW_PROGRESS STATUS cudnn_status
)
execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzvf ${cudnn_file} WORKING_DIRECTORY ${download_dir} RESULT_VARIABLE cudnn_status)
if(NOT "${cudnn_status}" MATCHES "0")
message(STATUS "Was not able to download CUDNN from ${cudnn_url}. Please install CuDNN manually from https://developer.nvidia.com/cuDNN")
endif()
endif()
if(dest_bindir AND tar_bindir)
file(COPY ${download_dir}/${tar_bindir}/ DESTINATION ${dest_bindir})
endif()
if(dest_incdir)
file(COPY ${download_dir}/${tar_incdir}/ DESTINATION ${dest_incdir})
endif()
file(COPY ${download_dir}/${tar_libdir}/ DESTINATION ${dest_libdir} )
get_filename_component(dest_dir ${dest_libdir} DIRECTORY)
set(CUDNN_ROOT_DIR ${dest_dir} PARENT_SCOPE)
unset(CUDNN_LIBRARY CACHE)
unset(CUDNN_INCLUDE_DIR CACHE)
endif(url_arch_name)
endfunction()
#####################################################
find_package(PkgConfig)
pkg_check_modules(PC_CUDNN QUIET CUDNN)
get_filename_component(__libpath_cudart "${CUDA_CUDART_LIBRARY}" PATH)
# We use major only in library search as major/minor is not entirely consistent among platforms.
# Also, looking for exact minor version of .so is in general not a good idea.
# More strict enforcement of minor/patch version is done if/when the header file is examined.
if(CUDNN_FIND_VERSION_EXACT)
SET(__cudnn_ver_suffix ".${CUDNN_FIND_VERSION_MAJOR}")
SET(__cudnn_lib_win_name cudnn64_${CUDNN_FIND_VERSION_MAJOR})
else()
SET(__cudnn_lib_win_name cudnn64)
endif()
find_library(CUDNN_LIBRARY
NAMES libcudnn.so${__cudnn_ver_suffix} libcudnn${__cudnn_ver_suffix}.dylib ${__cudnn_lib_win_name}
PATHS $ENV{LD_LIBRARY_PATH} ${__libpath_cudart} ${CUDNN_ROOT_DIR} ${PC_CUDNN_LIBRARY_DIRS} ${CMAKE_INSTALL_PREFIX}
PATH_SUFFIXES lib lib64 bin
DOC "CUDNN library." )
if(CUDNN_LIBRARY)
SET(CUDNN_MAJOR_VERSION ${CUDNN_FIND_VERSION_MAJOR})
set(CUDNN_VERSION ${CUDNN_MAJOR_VERSION})
get_filename_component(__found_cudnn_root ${CUDNN_LIBRARY} PATH)
find_path(CUDNN_INCLUDE_DIR
NAMES cudnn.h
HINTS ${PC_CUDNN_INCLUDE_DIRS} ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_INCLUDE} ${__found_cudnn_root}
PATH_SUFFIXES include
DOC "Path to CUDNN include directory." )
endif()
if(CUDNN_LIBRARY AND CUDNN_INCLUDE_DIR)
file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
CUDNN_MAJOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}")
string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)"
CUDNN_MINOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1"
CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}")
string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)"
CUDNN_PATCH_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1"
CUDNN_PATCH_VERSION "${CUDNN_PATCH_VERSION}")
set(CUDNN_VERSION ${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION})
endif()
if(CUDNN_MAJOR_VERSION)
## Fixing the case where 5.1 does not fit 'exact' 5.
if(CUDNN_FIND_VERSION_EXACT AND NOT CUDNN_FIND_VERSION_MINOR)
if("${CUDNN_MAJOR_VERSION}" STREQUAL "${CUDNN_FIND_VERSION_MAJOR}")
set(CUDNN_VERSION ${CUDNN_FIND_VERSION})
endif()
endif()
else()
# Try to set CUDNN version from config file
set(CUDNN_VERSION ${PC_CUDNN_CFLAGS_OTHER})
endif()
find_package_handle_standard_args(
CUDNN
FOUND_VAR CUDNN_FOUND
REQUIRED_VARS CUDNN_LIBRARY
VERSION_VAR CUDNN_VERSION
)
if(CUDNN_FOUND)
set(CUDNN_LIBRARIES ${CUDNN_LIBRARY})
set(CUDNN_INCLUDE_DIRS ${CUDNN_INCLUDE_DIR})
set(CUDNN_DEFINITIONS ${PC_CUDNN_CFLAGS_OTHER})
endif()

View file

@ -0,0 +1,18 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8)
add_subdirectory(cuda)
if(BUILD_TF)
add_subdirectory(tf_op)
endif()

View file

@ -0,0 +1,104 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Memory Allocator
**/
#pragma once
#include "fastertransformer/common.h"
#include "fastertransformer/utils.h"
#include <cuda_runtime.h>
#ifdef GOOGLE_CUDA
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#endif
namespace fastertransformer{
class IAllocator{
public:
virtual void* malloc(size_t size) const = 0;
virtual void free(void* ptr) const = 0;
};
template<AllocatorType AllocType_>
class Allocator;
template<>
class Allocator<AllocatorType::CUDA> : public IAllocator{
const int device_id_;
public:
Allocator(int device_id): device_id_(device_id){}
void* malloc(size_t size) const {
void* ptr = nullptr;
int o_device = 0;
check_cuda_error(get_set_device(device_id_, &o_device));
check_cuda_error(cudaMalloc(&ptr, size));
check_cuda_error(get_set_device(o_device));
return ptr;
}
void free(void* ptr) const {
int o_device = 0;
check_cuda_error(get_set_device(device_id_, &o_device));
check_cuda_error(cudaFree(ptr));
check_cuda_error(get_set_device(o_device));
return;
}
};
//TODO: allocator of TensorFlow
// You can add context to constructor
#ifdef GOOGLE_CUDA
using namespace tensorflow;
template<>
class Allocator<AllocatorType::TF> : public IAllocator{
OpKernelContext *context_;
public:
Allocator(OpKernelContext *context): context_(context){}
void* malloc(size_t size) const {
Tensor buf;
long long int buf_size = (long long int)size;
tensorflow::Status status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf);
if(status != tensorflow::Status::OK())
throw std::runtime_error("TF error: context->allocate_temp failed");
auto flat = buf.flat<uint8>();
void* ptr = (void*)flat.data();
return ptr;
}
void free(void* ptr) const {
printf("call from allocator free\n");
return;
}
};
#endif
}//namespace fastertransformer

View file

@ -0,0 +1,305 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* BERT Encoder transformer
**/
#pragma once
#include <cuda_runtime.h>
#include "fastertransformer/allocator.h"
#include "fastertransformer/cuda/cuda_kernels.h"
#include "fastertransformer/cuda/open_attention.h"
#include "fastertransformer/encoder_transformer.h"
namespace fastertransformer{
template<typename T>
class EncoderInitParam
{
public:
const T* from_tensor;
const T* to_tensor;
const T* attr_kernel_Q;
const T* attr_kernel_K;
const T* attr_kernel_V;
const T* attr_bias_Q;
const T* attr_bias_K;
const T* attr_bias_V;
const T* attr_mask;
const T* attr_output_kernel;
const T* attr_output_bias;
const T* attr_output_layernorm_gamma;
const T* attr_output_layernorm_beta;
const T* inter_kernel;
const T* inter_bias;
const T* output_kernel;
const T* output_bias;
const T* output_layernorm_gamma;
const T* output_layernorm_beta;
T* transformer_out;
cublasHandle_t cublas_handle;
cudaStream_t stream;
};
template<OperationType OpType_, template<OperationType> class MultiHeadAttention_>
class BertEncoderTransformerTraits;
template<template<OperationType> class MultiHeadAttention_>
class BertEncoderTransformerTraits<OperationType::FP32, MultiHeadAttention_>
{
public:
typedef float DataType;
static const OperationType OpType = OperationType::FP32;
typedef MultiHeadAttention_<OpType> MultiHeadAttention;
static cudaDataType_t const computeType = CUDA_R_32F;
static cudaDataType_t const AType = CUDA_R_32F;
static cudaDataType_t const BType = CUDA_R_32F;
static cudaDataType_t const CType = CUDA_R_32F;
//add FP32 Traits here
};
template<template<OperationType> class MultiHeadAttention_>
class BertEncoderTransformerTraits<OperationType::HALF, MultiHeadAttention_>
{
public:
typedef __half DataType;
static const OperationType OpType = OperationType::HALF;
typedef MultiHeadAttention_<OpType> MultiHeadAttention;
static cudaDataType_t const computeType = CUDA_R_16F;
static cudaDataType_t const AType = CUDA_R_16F;
static cudaDataType_t const BType = CUDA_R_16F;
static cudaDataType_t const CType = CUDA_R_16F;
//add HALF Traits here
};
template<class Traits_>
class BertEncoderTransformer:IEncoderTransformer<Traits_::OpType>
{
const IAllocator& allocator_;
typename Traits_::MultiHeadAttention *attention_;
typedef typename Traits_::DataType DataType_;
EncoderInitParam<DataType_> param_;
const cudaDataType_t computeType_ = Traits_::computeType;
const cudaDataType_t AType_ = Traits_::AType;
const cudaDataType_t BType_ = Traits_::BType;
const cudaDataType_t CType_ = Traits_::CType;
int cublasAlgo_[3];
DataType_* buf_;
DataType_* attr_out_buf_;
DataType_* attr_matmul_buf_;
DataType_* inter_matmul_buf_;
int batch_size_;
int from_seq_len_;
int to_seq_len_;
int head_num_;
int size_per_head_;
public:
BertEncoderTransformer(const IAllocator& allocator, int batch_size, int from_seq_len,
int to_seq_len, int head_num, int size_per_head):
allocator_(allocator), batch_size_(batch_size), from_seq_len_(from_seq_len),
to_seq_len_(to_seq_len), head_num_(head_num), size_per_head_(size_per_head){
#ifndef NDEBUG
PRINT_FUNC_NAME_();
#endif
int m = batch_size_ * from_seq_len_;
int k = head_num_ * size_per_head_;
int n = k;
int buf_size = m * n;
try
{
buf_ = reinterpret_cast<DataType_*>(allocator_.malloc(sizeof(DataType_) * buf_size * 6));
if(buf_ == nullptr)
throw std::runtime_error(std::string("Tensorflow Allocator failed to allocate internal buffer."));
attr_out_buf_ = buf_;
attr_matmul_buf_ = attr_out_buf_ + buf_size;
inter_matmul_buf_ = attr_matmul_buf_ + buf_size;
attention_ = new typename Traits_::MultiHeadAttention(allocator_, batch_size_, from_seq_len_, to_seq_len_, head_num_, size_per_head_);
FILE* fd = fopen("gemm_config.in", "r");
int err = 0;
if(fd == NULL)
printf("gemm_config.in is not found\n");
else
{
err = fscanf(fd, "%d%d%d%*d%*d", &cublasAlgo_[0], &cublasAlgo_[1], &cublasAlgo_[2]);
fclose(fd);
}
if(err != 3)
{
printf("loading GEMM algorithms error, using default GEMM algorithms!\n");
if(Traits_::OpType == OperationType::FP32)
{
cublasAlgo_[0] = -1;
cublasAlgo_[1] = -1;
cublasAlgo_[2] = -1;
}
else
{
cublasAlgo_[0] = 99;
cublasAlgo_[1] = 99;
cublasAlgo_[2] = 99;
}
}
}
catch(std::runtime_error& error)
{
throw error;
}
}
/**
* Initialize the parameters in class
* We will keep the Ctor empty to ensure the sub classes follow the same init routine.
* Please be aware that no dynamic memory allocation should be placed
**/
void initialize(EncoderInitParam<DataType_> param)
{
#ifndef NDEBUG
PRINT_FUNC_NAME_();
#endif
param_ = param;
cuda::MultiHeadInitParam<DataType_> multi_head_init_param;
multi_head_init_param.from_tensor = param.from_tensor;
multi_head_init_param.to_tensor = param.to_tensor;
multi_head_init_param.attr_kernel_Q = param.attr_kernel_Q;
multi_head_init_param.attr_kernel_K = param.attr_kernel_K;
multi_head_init_param.attr_kernel_V = param.attr_kernel_V;
multi_head_init_param.attr_bias_Q = param.attr_bias_Q;
multi_head_init_param.attr_bias_K = param.attr_bias_K;
multi_head_init_param.attr_bias_V = param.attr_bias_V;
multi_head_init_param.attr_mask = param.attr_mask;
multi_head_init_param.stream = param.stream;
multi_head_init_param.cublas_handle = param.cublas_handle;
multi_head_init_param.attr_out = attr_out_buf_;
attention_->initialize(multi_head_init_param);
}
/**
* do forward
**/
void forward() override
{
#ifndef NDEBUG
PRINT_FUNC_NAME_();
#endif
try{
attention_->forward();
#ifndef NDEBUG
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
#endif
DataType_ alpha = (DataType_)1.0f;
DataType_ beta = (DataType_)0.0f;
int m = batch_size_ * from_seq_len_;
int k = head_num_ * size_per_head_;
int n = k;
check_cuda_error(cublasGemmEx(param_.cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
param_.attr_output_kernel, AType_, n,
attr_out_buf_, BType_, k,
&beta,
attr_matmul_buf_, CType_, n,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));
add_bias_input_layernorm_kernelLauncher<DataType_>(attr_matmul_buf_,
param_.from_tensor, param_.attr_output_bias, param_.attr_output_layernorm_gamma,
param_.attr_output_layernorm_beta, m, n, param_.stream);
#ifndef NDEBUG
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
#endif
n *= 4;
check_cuda_error(cublasGemmEx(param_.cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
param_.inter_kernel, AType_, n,
attr_matmul_buf_, BType_, k,
&beta,
inter_matmul_buf_, CType_, n,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[1])));
add_bias_act_kernelLauncher<DataType_>(inter_matmul_buf_, param_.inter_bias, m, n, param_.stream);
#ifndef NDEBUG
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
#endif
n = k;
k *= 4;
check_cuda_error(cublasGemmEx(param_.cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
param_.output_kernel, AType_, n,
inter_matmul_buf_, BType_, k,
&beta,
param_.transformer_out, CType_, n,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[2])));
add_bias_input_layernorm_kernelLauncher<DataType_>(param_.transformer_out, attr_matmul_buf_, param_.output_bias,
param_.output_layernorm_gamma,
param_.output_layernorm_beta,
m, n, param_.stream);
#ifndef NDEBUG
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
#endif
}
catch(std::runtime_error& error)
{
throw error;
}
}
void trt_initialize(DataType_* from_tensor, DataType_* to_tensor, DataType_* attr_mask, DataType_* out, cudaStream_t stream, cublasHandle_t cublas_handle)
{
param_.from_tensor = from_tensor;
param_.to_tensor = to_tensor;
param_.stream = stream;
param_.transformer_out = out;
param_.cublas_handle = cublas_handle;
attention_->trt_initialize(from_tensor, to_tensor, attr_mask, stream, param_.cublas_handle);
}
~BertEncoderTransformer()
{
delete attention_;
allocator_.free(buf_);
}
};
}

View file

@ -0,0 +1,81 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
namespace fastertransformer{
enum class OperationType{FP32, HALF};
enum class AllocatorType{CUDA, TF};
#define PRINT_FUNC_NAME_() do{\
std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \
} while (0)
static const char *_cudaGetErrorEnum(cudaError_t error) {
return cudaGetErrorString(error);
}
static const char *_cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "<unknown>";
}
template <typename T>
void check(T result, char const *const func, const char *const file, int const line) {
if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + \
(_cudaGetErrorEnum(result)) + " " + file + \
":" + std::to_string(line) + " \n");\
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
}//namespace fastertransformer

View file

@ -0,0 +1,23 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8)
set(cuda_kernel_files
cuda_kernels.cu
open_attention.cu
)
add_library(fastertransformer STATIC ${cuda_kernel_files})
target_link_libraries(fastertransformer PUBLIC -lcublas -lcudart ${CMAKE_THREAD_LIBS_INIT})

View file

@ -0,0 +1,243 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_kernels.h"
#include <assert.h>
#include <cstdio>
#include <cstdlib>
namespace fastertransformer{
#define FINAL_MASK 0xffffffff
#define CUDART_PI_F 3.141592654f
template <typename T>
__inline__ __device__
T gelu(T x)
{
float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
return x * cdf;
}
template <>
__inline__ __device__
half2 gelu(half2 val)
{
half2 val_pow3 = __hmul2(val, __hmul2(val, val));
float2 tmp_pow = __half22float2(val_pow3);
float2 tmp = __half22float2(val);
tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
return __hmul2(val, __float22half2_rn(tmp));
}
template <typename T>
__inline__ __device__
T warpReduceSum(T val)
{
for(int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
template <typename T>
__inline__ __device__
T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if(lane == 0)
shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)0.0f;
val = warpReduceSum(val);
return val;
}
template <typename T>
__global__
void add_bias_act(T* out, const T* bias, int m, int n)
{
T val, reg_bias;
int row_id = blockIdx.x;
int ite = n / blockDim.x;
int tid = threadIdx.x;
for(int i = 0; i < ite; ++i)
{
reg_bias = __ldg(&bias[i * blockDim.x + tid]);
row_id = blockIdx.x;
while(row_id < m){
val = out[tid + i * blockDim.x + row_id * n]+ reg_bias;
out[tid + i * blockDim.x + row_id * n] = gelu<T>(val);
row_id += gridDim.x;
}
}
}
template <>
__global__
void add_bias_act(__half* out, const __half* bias, int m, int n)
{
half2 val, reg_bias;
int row_id = blockIdx.x;
int ite = n / blockDim.x / 2;
int tid = threadIdx.x;
half2* out_ptr = (half2*) out;
const half2* bias_ptr = (half2*) bias;
for(int i = 0; i < ite; ++i)
{
reg_bias = __ldg(&bias_ptr[i * blockDim.x + tid]);
row_id = blockIdx.x;
while(row_id < m){
val = out_ptr[tid + i * blockDim.x + row_id * n / 2];
val = __hadd2(val, reg_bias);
out_ptr[tid + i * blockDim.x + row_id * n / 2] = gelu<half2>(val);
row_id += gridDim.x;
}
}
}
template <typename T>
__global__
void add_bias_input_layernorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n)
{
int tid = threadIdx.x;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float local_out = 0.0f;
for(int i = tid; i < n; i += blockDim.x)
local_out += (float)(out[blockIdx.x * n + i] + input[blockIdx.x * n + i] + __ldg(&bias[i]));
mean = blockReduceSum<float>(local_out);
if(threadIdx.x == 0)
s_mean = mean / n;
__syncthreads();
variance = blockReduceSum<float>((local_out - s_mean) * (local_out - s_mean));
if(threadIdx.x == 0)
s_variance = variance / n + 1e-6f;
__syncthreads();
for(int i = tid; i < n; i += blockDim.x)
out[blockIdx.x * n + i] =
(T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i])));
}
template <>
__global__
void add_bias_input_layernorm(__half* out, const __half* input, const __half* bias,
const __half* gamma, const __half* beta, int m, int n)
{
int tid = threadIdx.x;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float2 local_out_fp2;
half2* out_ptr = (half2*)out;
const half2* input_ptr = (const half2*)input;
const half2* bias_ptr = (const half2*)bias;
const half2* gamma_ptr = (const half2*)gamma;
const half2* beta_ptr = (const half2*)beta;
float local_out = 0.0f;
int id = blockIdx.x * n / 2 + tid;
local_out_fp2 = __half22float2(__hadd2(__hadd2(out_ptr[id], input_ptr[id]), __ldg(&bias_ptr[tid])));
local_out += local_out_fp2.x;
local_out += local_out_fp2.y;
mean = blockReduceSum<float>(local_out);
if(threadIdx.x == 0)
s_mean = mean / n;
__syncthreads();
variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean);
variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean);
variance = blockReduceSum<float>(variance);
if(threadIdx.x == 0)
s_variance = rsqrtf(variance / n + 1e-6f);
__syncthreads();
float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid]));
float2 beta_val = __half22float2(__ldg(&beta_ptr[tid]));
local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x;
local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y;
out_ptr[id] = __float22half2_rn(local_out_fp2);
}
template <typename T>
void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream)
{
// dim3 grid(m / 64);
dim3 grid(m / 4);
dim3 block(n / 4);
assert(block.x > 1024);
// dim3 block(n);
add_bias_act<T><<<grid, block, 0, stream>>>(out, bias, m, n);
}
template<typename T>
void add_bias_input_layernorm_kernelLauncher(T* out, const T* input, const T* bias,
const T* gamma, const T* beta, int m, int n, cudaStream_t stream)
{
assert(n > 1024);
dim3 grid(m);
dim3 block(n);
add_bias_input_layernorm<T><<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, m, n);
}
template <>
void add_bias_input_layernorm_kernelLauncher(__half* out, const __half* input, const __half* bias,
const __half* gamma, const __half* beta, int m, int n, cudaStream_t stream)
{
assert(n / 2 > 1024);
dim3 grid(m);
dim3 block(n / 2);
add_bias_input_layernorm<__half><<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, m, n);
}
template void add_bias_act_kernelLauncher<float>(
float* out, const float* bias, int m, int n, cudaStream_t stream);
template void add_bias_input_layernorm_kernelLauncher<float>(
float* out, const float* input, const float* bias, const float* gamma, const float* beta,
int m, int n, cudaStream_t stream);
template void add_bias_act_kernelLauncher<__half>(
__half* out, const __half* bias, int m, int n, cudaStream_t stream);
template void add_bias_input_layernorm_kernelLauncher<__half>(
__half* out, const __half* input, const __half* bias, const __half* gamma, const __half* beta,
int m, int n, cudaStream_t stream);
}//namespace

View file

@ -0,0 +1,28 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace fastertransformer{
template <typename T>
void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream);
template <typename T>
void add_bias_input_layernorm_kernelLauncher(T* out, const T* input_tensor, const T* bias,
const T* gamma, const T* beta, int m, int n, cudaStream_t stream);
}//namespace fastertransformer

View file

@ -0,0 +1,84 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Multi-head attention interface
*/
#pragma once
#include "fastertransformer/common.h"
namespace fastertransformer{
namespace cuda{
template<typename T>
class MultiHeadInitParam{
public:
const T* from_tensor;
const T* to_tensor;
const T* attr_kernel_Q;
const T* attr_kernel_K;
const T* attr_kernel_V;
const T* attr_bias_Q;
const T* attr_bias_K;
const T* attr_bias_V;
const T* attr_mask;
T* attr_out;
cublasHandle_t cublas_handle;
cudaStream_t stream;
MultiHeadInitParam(){
from_tensor = nullptr;
to_tensor = nullptr;
attr_kernel_Q = nullptr;
attr_kernel_K = nullptr;
attr_kernel_V = nullptr;
attr_bias_Q = nullptr;
attr_bias_K = nullptr;
attr_bias_V = nullptr;
attr_mask = nullptr;
attr_out = nullptr;
cublas_handle = nullptr;
stream = 0;
}
};
/**
* Interface of attention operation
*/
template<OperationType OpType_>
class IMultiHeadAttention{
public:
// typedef MultiHeadInitParam<OpType_> InitParam;
/**
* do forward
**/
virtual void forward() = 0;
/**
* Initialize the parameters in class
* We will keep the Ctor empty to ensure the sub classes follow the same init routine.
* Please be aware that no dynamic memory allocation should be placed
**/
// virtual void free() = 0;
virtual ~IMultiHeadAttention(){}
};
}//namespace cuda
}//namespace fastertransformer

View file

@ -0,0 +1,388 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Open sourced multi-head attention
**/
#include "fastertransformer/allocator.h"
#include "fastertransformer/cuda/multi_head_attention.h"
#include "fastertransformer/cuda/open_attention.h"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cmath>
namespace fastertransformer{
namespace cuda{
/**
* Multi-head attetion open sourced
*/
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__
T warpReduceSum(T val)
{
for(int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__
T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if(lane == 0)
shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
__inline__ __device__
int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
{
return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
}
template<typename T>
__global__
void add_QKV_bias(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf_, T* k_buf_, T* v_buf_,
const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int word_per_block)
{
T* data_ptr;
T* buf_ptr;
const T* bias_ptr;
int m = batch_size * seq_len;
int n = head_num * size_per_head;
int qkv_id = blockIdx.x * word_per_block / m;
int row_offset = (blockIdx.x * word_per_block % m) * n;
if(qkv_id == 0)
{
data_ptr = Q + row_offset;
buf_ptr = q_buf_;
bias_ptr = bias_Q;
}
else if(qkv_id == 1)
{
data_ptr = K + row_offset;
buf_ptr = k_buf_;
bias_ptr = bias_K;
}
else
{
data_ptr = V + row_offset;
buf_ptr = v_buf_;
bias_ptr = bias_V;
}
int batch_id = (blockIdx.x * word_per_block % m) / seq_len;
int head_id = threadIdx.x / size_per_head;
int id_in_head = threadIdx.x % size_per_head;
int word_start_id = (blockIdx.x * word_per_block) % seq_len;
T bias = __ldg(&bias_ptr[threadIdx.x]);
for(int i = word_start_id; i < word_start_id + word_per_block; ++i)
{
T tmp = data_ptr[threadIdx.x] + bias;
int target_id = batch_id * (seq_len * head_num * size_per_head) + head_id * seq_len * size_per_head +
i * size_per_head + id_in_head;
buf_ptr[target_id] = tmp;
data_ptr += n;
}
}
template <>
__global__
void add_QKV_bias(__half* Q, const __half* bias_Q, __half* K, const __half* bias_K, __half* V, const __half* bias_V,
__half* q_buf_, __half* k_buf_, __half* v_buf_,
const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int word_per_block)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int batch_id = tid / (head_num * seq_len * size_per_head);
int seq_id = (tid % (head_num * seq_len * size_per_head)) / (head_num * size_per_head);
int head_id = (tid % (head_num * size_per_head)) / size_per_head;
int id = tid % size_per_head;
int target_id = target_index(batch_id, seq_id, head_id, id, batch_size, seq_len, head_num, size_per_head);
int bias_id = threadIdx.x;
half2* src_ptr = (half2*)Q;
half2* dst_ptr = (half2*)q_buf_;
const half2* bias_ptr = (const half2*)bias_Q;
dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id]));
src_ptr = (half2*)K;
dst_ptr = (half2*)k_buf_;
bias_ptr = (const half2*)bias_K;
dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id]));
src_ptr = (half2*)V;
dst_ptr = (half2*)v_buf_;
bias_ptr = (const half2*)bias_V;
dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id]));
}
template <typename T>
__global__
void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len,
const T scaler)
{
int batch_id = blockIdx.x / head_num;
int qk_offset = blockIdx.x * seq_len * seq_len;
int mask_offset = batch_id * seq_len * seq_len;
__shared__ float s_sum;
for(int i = 0; i < seq_len; ++i)
{
T qk = qk_buf_[threadIdx.x + qk_offset];
T mask_val = attr_mask[threadIdx.x + mask_offset];
mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
qk = __expf((float)(qk * scaler + mask_val));
float sum_val = blockReduceSum<float>(qk);
if(threadIdx.x == 0)
{
s_sum = sum_val + 1e-6f;
}
__syncthreads();
qk_buf_[threadIdx.x + qk_offset] = qk / (T)s_sum;
qk_offset += seq_len;
mask_offset += seq_len;
}
}
template <typename T>
__global__
void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num,
const int seq_len, const T scaler)
{
int batch_id = blockIdx.x / head_num / seq_len;
int seq_id = blockIdx.x % seq_len;
int qk_offset = blockIdx.x * seq_len;
int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
__shared__ float s_sum;
T qk = qk_buf_[threadIdx.x + qk_offset];
T mask_val = attr_mask[threadIdx.x + mask_offset];
mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
float qk_tmp = __expf((float)(qk * scaler + mask_val));
float sum_val = blockReduceSum<float>(qk_tmp);
if(threadIdx.x == 0)
{
s_sum = sum_val + 1e-6f;
}
__syncthreads();
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
}
template<typename T>
__global__
void transpose(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head)
{
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len))/ seq_len;
dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head
+ head_id * size_per_head + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
template<>
__global__
void transpose(__half* src, __half* dst,
const int batch_size, const int seq_len, const int head_num, const int size_per_head)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int batch_id = tid / (head_num * seq_len * size_per_head);
int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head);
int seq_id = (tid % (seq_len * size_per_head)) / size_per_head;
int id = tid % size_per_head;
int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head);
half2* src_ptr = (half2*)src;
half2* dst_ptr = (half2*)dst;
dst_ptr[target_id] = src_ptr[tid];
}
template<OperationType OpType_>
void OpenMultiHeadAttention<OpType_>::multiHeadAttr_nofuse_kernelLauncher(
cudaStream_t stream,
cublasHandle_t cublas_handle,
DataType_* Q,
const DataType_* bias_Q,
DataType_* K,
const DataType_* bias_K,
DataType_* V,
const DataType_* bias_V,
const DataType_* attr_mask,
DataType_* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const DataType_ scaler)
{
int m = batch_size * seq_len;
int k = head_num * size_per_head;
dim3 grid;
dim3 block;
if(OpType_ == OperationType::FP32)
{
const int word_per_block = 32;
assert(k > 1024);
assert(m / word_per_block * 3 > 65536);
dim3 grid(m / word_per_block * 3);
dim3 block(k);
add_QKV_bias<DataType_><<<grid, block, 0, stream>>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, v_buf_,
batch_size, seq_len, head_num, size_per_head, word_per_block);
}
else
{
const int word_per_block = 1;
grid.x = batch_size * seq_len / word_per_block;
block.x = head_num * size_per_head * word_per_block / 2;
assert(block.x);
add_QKV_bias<DataType_><<<grid, block, 0, stream>>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_,
v_buf_, batch_size, seq_len, head_num, size_per_head / 2, word_per_block);
}
DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f;
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_N,
seq_len, seq_len, size_per_head,
&alpha,
k_buf_, AType_, size_per_head, seq_len * size_per_head,
q_buf_, BType_, size_per_head, seq_len * size_per_head,
&beta,
qk_buf_, CType_, seq_len, seq_len * seq_len,
batch_size * head_num,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[1])));
if(batch_size * head_num <= 120)
{
grid.x = batch_size * head_num * seq_len;
block.x = seq_len;
softmax_kernel_v2<DataType_><<<grid, block, 0, stream>>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler);
}
else
{
grid.x = batch_size * head_num;
block.x = seq_len;
softmax_kernel<DataType_><<<grid, block, 0, stream>>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler);
}
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
size_per_head, seq_len, seq_len,
&alpha,
v_buf_, AType_, size_per_head, seq_len * size_per_head,
qk_buf_, BType_, seq_len, seq_len * seq_len,
&beta,
transpose_dst_, CType_, size_per_head, seq_len * size_per_head,
batch_size * head_num,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[2])));
/* for half2 only */
if(OpType_ == OperationType::HALF)
{
const int seq_per_block = 4;
grid.x = batch_size * head_num * seq_len / seq_per_block;
block.x = seq_per_block * size_per_head / 2;
transpose<DataType_><<<grid, block, 0, stream>>>(transpose_dst_, dst,
batch_size, seq_len, head_num, size_per_head / 2);
}
else
{
const int seq_per_block = 1;
grid.x = batch_size * head_num * seq_len / seq_per_block;
block.x = seq_per_block * size_per_head;
transpose<DataType_><<<grid, block, 0, stream>>>(transpose_dst_, dst,
batch_size, seq_len, head_num, size_per_head);
}
}
template void OpenMultiHeadAttention<OperationType::FP32>::multiHeadAttr_nofuse_kernelLauncher(
cudaStream_t stream,
cublasHandle_t handle,
float* Q,
const float* bias_Q,
float* K,
const float* bias_K,
float* V,
const float* bias_V,
const float* attr_mask,
float* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float scaler);
template void OpenMultiHeadAttention<OperationType::HALF>::multiHeadAttr_nofuse_kernelLauncher(
cudaStream_t stream,
cublasHandle_t handle,
__half* Q,
const __half* bias_Q,
__half* K,
const __half* bias_K,
__half* V,
const __half* bias_V,
const __half* attr_mask,
__half* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const __half scaler);
}//namespace cuda
}//namespace fastertransformer

View file

@ -0,0 +1,277 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Open sourced multi-head attention
**/
#pragma once
#include "fastertransformer/allocator.h"
#include "fastertransformer/cuda/multi_head_attention.h"
#include <assert.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace fastertransformer{
namespace cuda{
template<OperationType OpType_>
class OpenMultiHeadAttentionTraits;
template<>
class OpenMultiHeadAttentionTraits<OperationType::FP32>
{
public:
typedef float DataType;
static cudaDataType_t const computeType = CUDA_R_32F;
static cudaDataType_t const AType = CUDA_R_32F;
static cudaDataType_t const BType = CUDA_R_32F;
static cudaDataType_t const CType = CUDA_R_32F;
//others
};
template<>
class OpenMultiHeadAttentionTraits<OperationType::HALF>
{
public:
typedef __half DataType;
static cudaDataType_t const computeType = CUDA_R_16F;
static cudaDataType_t const AType = CUDA_R_16F;
static cudaDataType_t const BType = CUDA_R_16F;
static cudaDataType_t const CType = CUDA_R_16F;
//others
};
/**
* Multi-head attetion open sourced
*/
template<OperationType OpType_>
class OpenMultiHeadAttention: IMultiHeadAttention<OpType_>
{
private:
typedef OpenMultiHeadAttentionTraits<OpType_> Traits_;
typedef typename Traits_::DataType DataType_;
const cudaDataType_t computeType_ = Traits_::computeType;
const cudaDataType_t AType_ = Traits_::AType;
const cudaDataType_t BType_ = Traits_::BType;
const cudaDataType_t CType_ = Traits_::CType;
const IAllocator& allocator_;
MultiHeadInitParam<DataType_> param_;
int cublasAlgo_[3];
DataType_* buf_;
DataType_* query_buf_;
DataType_* key_buf_;
DataType_* value_buf_;
DataType_* q_buf_;
DataType_* k_buf_;
DataType_* v_buf_;
DataType_* qk_buf_;
DataType_* transpose_dst_;
int batch_size_;
int from_seq_len_;
int to_seq_len_;
int head_num_;
int size_per_head_;
public:
//Ctor
OpenMultiHeadAttention(const IAllocator& allocator, int batch_size, int from_seq_len,
int to_seq_len, int head_num, int size_per_head):
allocator_(allocator), batch_size_(batch_size), from_seq_len_(from_seq_len), to_seq_len_(to_seq_len),
head_num_(head_num), size_per_head_(size_per_head)
{
#ifndef NDEBUG
PRINT_FUNC_NAME_();
#endif
int buf_size = batch_size_ * head_num_ * from_seq_len_ * size_per_head_;
int qk_buf_size = batch_size_ * head_num_ * from_seq_len_ * from_seq_len_;
try
{
buf_ = (DataType_*) allocator_.malloc(sizeof(DataType_) * (buf_size * 7 + qk_buf_size));
query_buf_ = buf_;
key_buf_ = buf_ + buf_size;
value_buf_ = buf_ + 2 * buf_size;
q_buf_ = buf_ + 3 * buf_size;
k_buf_ = buf_ + 4 * buf_size;
v_buf_ = buf_ + 5 * buf_size;
qk_buf_ = buf_ + 6 * buf_size;
transpose_dst_ = qk_buf_ + qk_buf_size;
FILE* fd = fopen("gemm_config.in", "r");
int err = 0;
if(fd == NULL)
printf("gemm_config.in is not found\n");
else
{
err = fscanf(fd, "%d%*d%*d%d%d", &cublasAlgo_[0], &cublasAlgo_[1], &cublasAlgo_[2]);
fclose(fd);
}
if(err != 3)
{
printf("loading GEMM algorithms error, using default GEMM algorithms\n");
if(OpType_ == OperationType::FP32)
{
cublasAlgo_[0] = -1;
cublasAlgo_[1] = -1;
cublasAlgo_[2] = -1;
}
else
{
cublasAlgo_[0] = 99;
cublasAlgo_[1] = 99;
cublasAlgo_[2] = 99;
}
}
}
catch(std::runtime_error& error)
{
throw error;
}
}
void forward()
{
#ifndef NDEBUG
PRINT_FUNC_NAME_();
#endif
int m = batch_size_ * from_seq_len_;
int k = head_num_ * size_per_head_;
int n = k;
DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f;
try
{
check_cuda_error(cublasGemmEx(param_.cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
param_.attr_kernel_Q, AType_, n,
param_.from_tensor, BType_, k,
&beta,
query_buf_, CType_, n,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));
check_cuda_error(cublasGemmEx(param_.cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
param_.attr_kernel_K, AType_, n,
param_.to_tensor, BType_, k,
&beta,
key_buf_, CType_, n,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));
check_cuda_error(cublasGemmEx(param_.cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
param_.attr_kernel_V, AType_, n,
param_.to_tensor, BType_, k,
&beta,
value_buf_, CType_, n,
computeType_,
static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));
DataType_ scaler = 1 / sqrtf(size_per_head_ * 1.0f);
multiHeadAttr_nofuse_kernelLauncher(
param_.stream,
param_.cublas_handle,
query_buf_,
param_.attr_bias_Q,
key_buf_,
param_.attr_bias_K,
value_buf_,
param_.attr_bias_V,
param_.attr_mask,
param_.attr_out,
batch_size_,
from_seq_len_,
head_num_,
size_per_head_,
scaler);
}
catch(std::runtime_error& error)
{
throw error;
}
}
void multiHeadAttr_kernelLauncher(
cudaStream_t stream,
const DataType_* Q,
const DataType_* bias_Q,
const DataType_* K,
const DataType_* bias_K,
const DataType_* V,
const DataType_* bias_V,
const DataType_* attr_mask,
DataType_* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const DataType_ scaler);
void multiHeadAttr_nofuse_kernelLauncher(
cudaStream_t stream,
cublasHandle_t handle,
DataType_* Q,
const DataType_* bias_Q,
DataType_* K,
const DataType_* bias_K,
DataType_* V,
const DataType_* bias_V,
const DataType_* attr_mask,
DataType_* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const DataType_ scaler);
void initialize(MultiHeadInitParam<DataType_> param)
{
#ifndef NDEBUG
PRINT_FUNC_NAME_();
#endif
//Do all the malloc here
param_ = param;
}
void trt_initialize(DataType_* from_tensor, DataType_* to_tensor, DataType_* attr_mask, cudaStream_t stream,
cublasHandle_t cublas_handle)
{
param_.from_tensor = from_tensor;
param_.to_tensor = to_tensor;
param_.attr_mask = attr_mask;
param_.stream = stream;
param_.cublas_handle = cublas_handle;
}
~OpenMultiHeadAttention() override
{
allocator_.free(buf_);
}
};
}//namespace cuda
}//namespace fastertransformer

View file

@ -0,0 +1,39 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Encoder transformer
**/
#pragma once
#include "fastertransformer/common.h"
namespace fastertransformer{
template<OperationType OpType_>
class IEncoderTransformer{
public:
/**
* do forward
**/
virtual void forward() = 0;
virtual ~IEncoderTransformer(){}
};
} //namespace fastertransformer

View file

@ -0,0 +1,28 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* c++ interface of Faster Transformer
**/
#pragma once
#include "fastertransformer/bert_encoder_transformer.h"
#include <cuda_fp16.h>
namespace fastertransformer{
}//namespace fastertransformer

View file

@ -0,0 +1,28 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8)
set(tf_bert_transformer_files
bert_transformer_op.cc
bert_transformer_op.cu.cc
../cuda/open_attention.cu
../cuda/cuda_kernels.cu
)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-DGOOGLE_CUDA=1)
add_definitions(-DNDEBUG)
add_library(tf_fastertransformer SHARED ${tf_bert_transformer_files})
target_link_libraries(tf_fastertransformer PRIVATE -lcublas -lcudart -ltensorflow_framework ${CMAKE_THREAD_LIBS_INIT})

View file

@ -0,0 +1,185 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fastertransformer/faster_transformer.h"
#include "fastertransformer/tf_op/bert_transformer_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/register_types.h"
#include <cuda_fp16.h>
namespace tensorflow
{
namespace
{
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
REGISTER_OP("BertTransformer")
.Input("from_tensor: T")
.Input("to_tensor: T")
.Input("attr_kernel_q: T")
.Input("attr_kernel_k: T")
.Input("attr_kernel_v: T")
.Input("attr_bias_q: T")
.Input("attr_bias_k: T")
.Input("attr_bias_v: T")
.Input("attr_mask: T")
.Input("attr_output_kernel: T")
.Input("attr_output_bias: T")
.Input("attr_output_layernorm_beta: T")
.Input("attr_output_layernorm_gamma: T")
.Input("inter_kernel: T")
.Input("inter_bias: T")
.Input("output_kernel: T")
.Input("output_bias: T")
.Input("output_layernorm_beta: T")
.Input("output_layernorm_gamma: T")
.Output("output: T")
.Attr("T: {float, half}")
.Attr("batch_size: int >= 1")
.Attr("from_seq_len: int >= 1")
.Attr("to_seq_len: int >= 1")
.Attr("head_num: int >= 1")
.Attr("size_per_head: int >= 1")
.SetShapeFn([](shape_inference::InferenceContext *c) {
int batch_size, from_seq_len, to_seq_len, head_num, size_per_head;
c->GetAttr("batch_size", &batch_size);
c->GetAttr("from_seq_len", &from_seq_len);
c->GetAttr("to_seq_len", &to_seq_len);
c->GetAttr("head_num", &head_num);
c->GetAttr("size_per_head", &size_per_head);
c->set_output(0, c->MakeShape({batch_size * from_seq_len, head_num * size_per_head}));
return Status::OK();
});
template <typename Device, typename T>
class BertTransformerOp : public OpKernel
{
public:
explicit BertTransformerOp(OpKernelConstruction *context) : OpKernel(context)
{
OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_));
OP_REQUIRES_OK(context, context->GetAttr("from_seq_len", &from_seq_len_));
OP_REQUIRES_OK(context, context->GetAttr("to_seq_len", &to_seq_len_));
OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_));
OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_));
OP_REQUIRES(context, (from_seq_len_ == to_seq_len_),
errors::InvalidArgument("Only support from_seq_len == to_seq_len"));
try
{
check_cuda_error(cublasCreate(&cublas_handle_));
}
catch(std::runtime_error& error)
{
OP_REQUIRES(context, false, errors::Internal(error.what()));
}
}
void Compute(OpKernelContext *context) override
{
typedef BertEncoderTransformerTraits<traits_::OpType, cuda::OpenMultiHeadAttention> EncoderTraits_;
BertEncoderTransformer<EncoderTraits_> *encoder_transformer_;
try
{
fastertransformer::Allocator<AllocatorType::TF> allocator_(context);
encoder_transformer_ = new BertEncoderTransformer<EncoderTraits_>(allocator_,
batch_size_, from_seq_len_, to_seq_len_, head_num_, size_per_head_);
}
catch(std::runtime_error& error)
{
OP_REQUIRES(context, false, errors::Internal(error.what()));
}
OP_REQUIRES(context, context->num_inputs() == 19, errors::InvalidArgument("Less input arguments"));
EncoderInitParam<DataType_> param; //init param here
param.cublas_handle = cublas_handle_;
param.from_tensor = reinterpret_cast<const DataType_ *>(context->input(0).flat<T>().data());
param.to_tensor = reinterpret_cast<const DataType_ *>(context->input(1).flat<T>().data());
param.attr_kernel_Q = reinterpret_cast<const DataType_ *>(context->input(2).flat<T>().data());
param.attr_kernel_K = reinterpret_cast<const DataType_ *>(context->input(3).flat<T>().data());
param.attr_kernel_V = reinterpret_cast<const DataType_ *>(context->input(4).flat<T>().data());
param.attr_bias_Q = reinterpret_cast<const DataType_ *>(context->input(5).flat<T>().data());
param.attr_bias_K = reinterpret_cast<const DataType_ *>(context->input(6).flat<T>().data());
param.attr_bias_V = reinterpret_cast<const DataType_ *>(context->input(7).flat<T>().data());
param.attr_mask = reinterpret_cast<const DataType_ *>(context->input(8).flat<T>().data());
param.attr_output_kernel = reinterpret_cast<const DataType_ *>(context->input(9).flat<T>().data());
param.attr_output_bias = reinterpret_cast<const DataType_ *>(context->input(10).flat<T>().data());
param.attr_output_layernorm_beta = reinterpret_cast<const DataType_ *>(context->input(11).flat<T>().data());
param.attr_output_layernorm_gamma = reinterpret_cast<const DataType_ *>(context->input(12).flat<T>().data());
param.inter_kernel = reinterpret_cast<const DataType_ *>(context->input(13).flat<T>().data());
param.inter_bias = reinterpret_cast<const DataType_ *>(context->input(14).flat<T>().data());
param.output_kernel = reinterpret_cast<const DataType_ *>(context->input(15).flat<T>().data());
param.output_bias = reinterpret_cast<const DataType_ *>(context->input(16).flat<T>().data());
param.output_layernorm_beta = reinterpret_cast<const DataType_ *>(context->input(17).flat<T>().data());
param.output_layernorm_gamma = reinterpret_cast<const DataType_ *>(context->input(18).flat<T>().data());
OP_REQUIRES(context, param.from_tensor != nullptr, errors::InvalidArgument("from tensor is null"));
OP_REQUIRES(context, param.to_tensor != nullptr, errors::InvalidArgument("to tensor is null"));
OP_REQUIRES(context, param.attr_kernel_Q != nullptr, errors::InvalidArgument("attr_kernel_Q is null"));
OP_REQUIRES(context, param.attr_kernel_K != nullptr, errors::InvalidArgument("attr_kernel_K is null"));
OP_REQUIRES(context, param.attr_kernel_V != nullptr, errors::InvalidArgument("attr_kernel_V is null"));
OP_REQUIRES(context, param.attr_bias_Q != nullptr, errors::InvalidArgument("attr_bias_Q is null"));
OP_REQUIRES(context, param.attr_bias_K != nullptr, errors::InvalidArgument("attr_bias_K is null"));
OP_REQUIRES(context, param.attr_bias_V != nullptr, errors::InvalidArgument("attr_bias_V is null"));
OP_REQUIRES(context, param.attr_mask != nullptr, errors::InvalidArgument("attr_mask is null"));
OP_REQUIRES(context, param.attr_output_kernel != nullptr, errors::InvalidArgument("attr_output_kernel is null"));
OP_REQUIRES(context, param.attr_output_bias != nullptr, errors::InvalidArgument("attr_output_bias is null"));
OP_REQUIRES(context, param.attr_output_layernorm_beta != nullptr, errors::InvalidArgument("attr_output_layernorm_beta is null"));
OP_REQUIRES(context, param.attr_output_layernorm_gamma != nullptr, errors::InvalidArgument("attr_output_layernorm_gamma is null"));
OP_REQUIRES(context, param.inter_kernel != nullptr, errors::InvalidArgument("inter_kernel is null"));
OP_REQUIRES(context, param.inter_bias != nullptr, errors::InvalidArgument("inter_bias is null"));
OP_REQUIRES(context, param.output_kernel != nullptr, errors::InvalidArgument("output_kernel is null"));
OP_REQUIRES(context, param.output_bias != nullptr, errors::InvalidArgument("output_bias is null"));
OP_REQUIRES(context, param.output_layernorm_beta != nullptr, errors::InvalidArgument("output_layernorm_beta is null"));
OP_REQUIRES(context, param.output_layernorm_gamma != nullptr, errors::InvalidArgument("output_layernorm_gamma is null"));
Tensor *output = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(0, {batch_size_ * from_seq_len_, head_num_ * size_per_head_}, &output));
param.transformer_out = reinterpret_cast<DataType_ *>(output->flat<T>().data());
OP_REQUIRES_OK(
context,
functor::BertTransformerOpFunctor<Device, T>::Compute(
context,
param,
encoder_transformer_));
}
private:
int batch_size_, from_seq_len_, to_seq_len_, head_num_, size_per_head_;
typedef TransformerTFTraits<T> traits_;
typedef typename traits_::DataType DataType_;
cublasHandle_t cublas_handle_;
};
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("BertTransformer").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
BertTransformerOp<GPUDevice, T>)
REGISTER_GPU(float);
REGISTER_GPU(Eigen::half);
#undef REGISTER_GPU
#endif
} //namespace
} //namespace tensorflow

View file

@ -0,0 +1,64 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "fastertransformer/tf_op/bert_transformer_op.h"
#include "fastertransformer/common.h"
#include "fastertransformer/faster_transformer.h"
#include "tensorflow/core/framework/op.h"
#include <cuda_runtime.h>
#include <string>
namespace tensorflow
{
using GPUDevice = Eigen::GpuDevice;
using namespace fastertransformer;
namespace functor
{
template <typename T>
struct BertTransformerOpFunctor<GPUDevice, T>
{
typedef typename TransformerTFTraits<T>::DataType DataType_;
static Status Compute(OpKernelContext *context,
EncoderInitParam<DataType_ > param,
BertEncoderTransformer<BertEncoderTransformerTraits< TransformerTFTraits<T>::OpType,
cuda::OpenMultiHeadAttention > > *encoder_transformer)
{
const cudaStream_t &stream = context->eigen_device<GPUDevice>().stream();
param.stream = stream;
try
{
check_cuda_error(cublasSetStream(param.cublas_handle, stream));
encoder_transformer->initialize(param);
encoder_transformer->forward();
return Status::OK();
}
catch(std::runtime_error& error)
{
return errors::Internal(error.what());
}
catch(...)
{
return errors::Internal("Runtime error");
}
}
};
} //namespace functor
template struct functor::BertTransformerOpFunctor<GPUDevice, float>;
template struct functor::BertTransformerOpFunctor<GPUDevice, Eigen::half>;
} //namespace tensorflow
#endif

View file

@ -0,0 +1,63 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef TENSORFLOW_CORE_KERNELS_MULTIHEADATTR_OP_H_
#define TENSORFLOW_CORE_KERNELS_MULTIHEADATTR_OP_H_
#include "fastertransformer/common.h"
#include "fastertransformer/faster_transformer.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <cublas_v2.h>
using namespace fastertransformer;
namespace tensorflow
{
template <typename T> class TransformerTFTraits;
template <>
class TransformerTFTraits<float>
{
public:
typedef float DataType;
static const OperationType OpType = OperationType::FP32;
};
template <>
class TransformerTFTraits<Eigen::half>
{
public:
typedef __half DataType;
static const OperationType OpType = OperationType::HALF;
};
namespace functor
{
template <typename Device, typename T>
struct BertTransformerOpFunctor
{
typedef typename TransformerTFTraits<T>::DataType DataType_;
static Status Compute(OpKernelContext *context,
EncoderInitParam<DataType_ > param,
BertEncoderTransformer<BertEncoderTransformerTraits< TransformerTFTraits<T>::OpType,
cuda::OpenMultiHeadAttention > > *encoder_transformer);
};
} //namespace functor
} //namespace tensorflow
#endif

View file

@ -0,0 +1,367 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "fastertransformer/faster_transformer.h"
#include <assert.h>
#include <iostream>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <vector>
#include <iomanip>
#include <chrono>
#include "NvInfer.h"
#include "NvCaffeParser.h"
using namespace nvinfer1;
using namespace nvcaffeparser1;
using namespace std;
using namespace fastertransformer;
template <typename T> class TransformerTrtTraits;
template <>
class TransformerTrtTraits<float>
{
public:
static const OperationType OpType = OperationType::FP32;
static const nvinfer1::DataType DataType = nvinfer1::DataType::kFLOAT;
};
template <>
class TransformerTrtTraits<__half>
{
public:
static const OperationType OpType = OperationType::HALF;
static const nvinfer1::DataType DataType = nvinfer1::DataType::kHALF;
};
class Logger : public nvinfer1::ILogger
{
public:
Logger(Severity severity = Severity::kINFO) : reportableSeverity(severity) {}
void log(Severity severity, const char* msg) override
{
if (severity > reportableSeverity) return;
switch (severity)
{
case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
case Severity::kERROR: std::cerr << "ERROR: "; break;
case Severity::kWARNING: std::cerr << "WARNING: "; break;
case Severity::kINFO: std::cerr << "INFO: "; break;
default: std::cerr << "UNKNOWN: "; break;
}
std::cerr << msg << std::endl;
}
Severity reportableSeverity;
};
static Logger gLogger(ILogger::Severity::kWARNING);
template <typename T>
class TransformerPlugin: public IPluginV2
{
public:
TransformerPlugin(
int hidden_dim, int head_num, int seq_len, int max_batch_size,
const nvinfer1::Weights &w_attr_kernel_Q,
const nvinfer1::Weights &w_attr_kernel_K,
const nvinfer1::Weights &w_attr_kernel_V,
const nvinfer1::Weights &w_attr_bias_Q,
const nvinfer1::Weights &w_attr_bias_K,
const nvinfer1::Weights &w_attr_bias_V,
const nvinfer1::Weights &w_attr_output_kernel,
const nvinfer1::Weights &w_attr_output_bias,
const nvinfer1::Weights &w_attr_output_layernorm_beta,
const nvinfer1::Weights &w_attr_output_layernorm_gamma,
const nvinfer1::Weights &w_inter_kernel,
const nvinfer1::Weights &w_inter_bias,
const nvinfer1::Weights &w_output_kernel,
const nvinfer1::Weights &w_output_bias,
const nvinfer1::Weights &w_output_layernorm_beta,
const nvinfer1::Weights &w_output_layernorm_gamma
): hidden_dim_(hidden_dim), head_num_(head_num), seq_len_(seq_len), max_batch_size_(max_batch_size)
{
cudaMallocAndCopy(d_attr_kernel_Q_, w_attr_kernel_Q, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_kernel_K_, w_attr_kernel_K, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_kernel_V_, w_attr_kernel_V, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_bias_Q_, w_attr_bias_Q, hidden_dim);
cudaMallocAndCopy(d_attr_bias_K_, w_attr_bias_K, hidden_dim);
cudaMallocAndCopy(d_attr_bias_V_, w_attr_bias_V, hidden_dim);
cudaMallocAndCopy(d_attr_output_kernel_, w_attr_output_kernel, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_output_bias_, w_attr_output_bias, hidden_dim);
cudaMallocAndCopy(d_attr_output_layernorm_beta_, w_attr_output_layernorm_beta, hidden_dim);
cudaMallocAndCopy(d_attr_output_layernorm_gamma_, w_attr_output_layernorm_gamma, hidden_dim);
cudaMallocAndCopy(d_inter_kernel_, w_inter_kernel, hidden_dim * hidden_dim * 4);
cudaMallocAndCopy(d_inter_bias_, w_inter_bias, hidden_dim * 4);
cudaMallocAndCopy(d_output_kernel_, w_output_kernel, hidden_dim * hidden_dim * 4);
cudaMallocAndCopy(d_output_bias_, w_output_bias, hidden_dim);
cudaMallocAndCopy(d_output_layernorm_beta_, w_output_layernorm_beta, hidden_dim);
cudaMallocAndCopy(d_output_layernorm_gamma_, w_output_layernorm_gamma, hidden_dim);
/* should modify 0 to current device id */
try
{
check_cuda_error(cublasCreate(&cublas_handle_));
int device_id;
check_cuda_error(cudaGetDevice(&device_id));
allocator_ = new fastertransformer::Allocator<AllocatorType::CUDA>(device_id);
encoder_transformer_ = new
BertEncoderTransformer<EncoderTraits_>(*allocator_, max_batch_size, seq_len, seq_len, head_num, hidden_dim / head_num);
EncoderInitParam<T> encoder_param; //init param here
encoder_param.attr_kernel_Q = d_attr_kernel_Q_;
encoder_param.attr_kernel_K = d_attr_kernel_K_;
encoder_param.attr_kernel_V = d_attr_kernel_V_;
encoder_param.attr_bias_Q = d_attr_bias_Q_;
encoder_param.attr_bias_K = d_attr_bias_K_;
encoder_param.attr_bias_V = d_attr_bias_V_;
encoder_param.attr_output_kernel = d_attr_output_kernel_;
encoder_param.attr_output_bias = d_attr_output_bias_;
encoder_param.attr_output_layernorm_beta = d_attr_output_layernorm_beta_;
encoder_param.attr_output_layernorm_gamma = d_attr_output_layernorm_gamma_;
encoder_param.inter_kernel = d_inter_kernel_;
encoder_param.inter_bias = d_inter_bias_;
encoder_param.output_kernel = d_output_kernel_;
encoder_param.output_bias = d_output_bias_;
encoder_param.output_layernorm_beta = d_output_layernorm_beta_;
encoder_param.output_layernorm_gamma = d_output_layernorm_gamma_;
encoder_param.cublas_handle = cublas_handle_;
encoder_transformer_->initialize(encoder_param);
}
catch(std::runtime_error& error)
{
std::cout << error.what() << std::endl;
}
}
TransformerPlugin(
int hidden_dim, int head_num, int seq_len, int max_batch_size,
const T* dp_attr_kernel_Q,
const T* dp_attr_kernel_K,
const T* dp_attr_kernel_V,
const T* dp_attr_bias_Q,
const T* dp_attr_bias_K,
const T* dp_attr_bias_V,
const T* dp_attr_output_kernel,
const T* dp_attr_output_bias,
const T* dp_attr_output_layernorm_beta,
const T* dp_attr_output_layernorm_gamma,
const T* dp_inter_kernel,
const T* dp_inter_bias,
const T* dp_output_kernel,
const T* dp_output_bias,
const T* dp_output_layernorm_beta,
const T* dp_output_layernorm_gamma
): hidden_dim_(hidden_dim), head_num_(head_num), seq_len_(seq_len), max_batch_size_(max_batch_size)
{
cudaMallocAndCopy(d_attr_kernel_Q_, dp_attr_kernel_Q, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_kernel_K_, dp_attr_kernel_K, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_kernel_V_, dp_attr_kernel_V, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_bias_Q_, dp_attr_bias_Q, hidden_dim);
cudaMallocAndCopy(d_attr_bias_K_, dp_attr_bias_K, hidden_dim);
cudaMallocAndCopy(d_attr_bias_V_, dp_attr_bias_V, hidden_dim);
cudaMallocAndCopy(d_attr_output_kernel_, dp_attr_output_kernel, hidden_dim * hidden_dim);
cudaMallocAndCopy(d_attr_output_bias_, dp_attr_output_bias, hidden_dim);
cudaMallocAndCopy(d_attr_output_layernorm_beta_, dp_attr_output_layernorm_beta, hidden_dim);
cudaMallocAndCopy(d_attr_output_layernorm_gamma_, dp_attr_output_layernorm_gamma, hidden_dim);
cudaMallocAndCopy(d_inter_kernel_, dp_inter_kernel, hidden_dim * hidden_dim * 4);
cudaMallocAndCopy(d_inter_bias_, dp_inter_bias, hidden_dim * 4);
cudaMallocAndCopy(d_output_kernel_, dp_output_kernel, hidden_dim * hidden_dim * 4);
cudaMallocAndCopy(d_output_bias_, dp_output_bias, hidden_dim);
cudaMallocAndCopy(d_output_layernorm_beta_, dp_output_layernorm_beta, hidden_dim);
cudaMallocAndCopy(d_output_layernorm_gamma_, dp_output_layernorm_gamma, hidden_dim);
try
{
check_cuda_error(cublasCreate(&cublas_handle_));
/* should modify 0 to current device id */
int device_id;
check_cuda_error(cudaGetDevice(&device_id));
allocator_ = new fastertransformer::Allocator<AllocatorType::CUDA>(device_id);
encoder_transformer_ = new
BertEncoderTransformer<EncoderTraits_>(*allocator_, max_batch_size, seq_len, seq_len, head_num, hidden_dim / head_num);
EncoderInitParam<T> encoder_param; //init param here
encoder_param.attr_kernel_Q = d_attr_kernel_Q_;
encoder_param.attr_kernel_K = d_attr_kernel_K_;
encoder_param.attr_kernel_V = d_attr_kernel_V_;
encoder_param.attr_bias_Q = d_attr_bias_Q_;
encoder_param.attr_bias_K = d_attr_bias_K_;
encoder_param.attr_bias_V = d_attr_bias_V_;
encoder_param.attr_output_kernel = d_attr_output_kernel_;
encoder_param.attr_output_bias = d_attr_output_bias_;
encoder_param.attr_output_layernorm_beta = d_attr_output_layernorm_beta_;
encoder_param.attr_output_layernorm_gamma = d_attr_output_layernorm_gamma_;
encoder_param.inter_kernel = d_inter_kernel_;
encoder_param.inter_bias = d_inter_bias_;
encoder_param.output_kernel = d_output_kernel_;
encoder_param.output_bias = d_output_bias_;
encoder_param.output_layernorm_beta = d_output_layernorm_beta_;
encoder_param.output_layernorm_gamma = d_output_layernorm_gamma_;
encoder_param.cublas_handle = cublas_handle_;
encoder_transformer_->initialize(encoder_param);
}
catch(std::runtime_error& error)
{
std::cout << error.what() << std::endl;
}
}
~TransformerPlugin()
{
try{
check_cuda_error(cudaFree(d_attr_kernel_Q_));
check_cuda_error(cudaFree(d_attr_kernel_K_));
check_cuda_error(cudaFree(d_attr_kernel_V_));
check_cuda_error(cudaFree(d_attr_bias_Q_));
check_cuda_error(cudaFree(d_attr_bias_K_));
check_cuda_error(cudaFree(d_attr_bias_V_));
check_cuda_error(cudaFree(d_attr_output_kernel_));
check_cuda_error(cudaFree(d_attr_output_bias_));
check_cuda_error(cudaFree(d_attr_output_layernorm_beta_));
check_cuda_error(cudaFree(d_attr_output_layernorm_gamma_));
check_cuda_error(cudaFree(d_inter_kernel_));
check_cuda_error(cudaFree(d_inter_bias_));
check_cuda_error(cudaFree(d_output_kernel_));
check_cuda_error(cudaFree(d_output_bias_));
check_cuda_error(cudaFree(d_output_layernorm_beta_));
check_cuda_error(cudaFree(d_output_layernorm_gamma_));
check_cuda_error(cublasDestroy(cublas_handle_));
delete encoder_transformer_;
}
catch(std::runtime_error& error)
{
std::cout << error.what() << std::endl;
}
}
virtual size_t getSerializationSize() const override {return 0;}
virtual void serialize(void* buffer) const override {}
int getNbOutputs() const override {return 1;}
Dims getOutputDimensions(int index, const Dims* pInputDim, int nInputDim) override
{
assert(index == 0 && nInputDim == 2);
return DimsHW(seq_len_, hidden_dim_);
}
bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override
{
return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW;
}
void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim,
int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override
{
assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
assert(nInputDim == 2);
assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);
assert(nOutputDim == 1);
assert(pOutputDim[0].nbDims == 2 && pOutputDim[0].d[0] == seq_len_ && pOutputDim[0].d[1] == hidden_dim_);
}
virtual int enqueue(int batch_size, const void * const *inputs, void **outputs, void* workspace, cudaStream_t stream) override
{
T* from_tensor = (T*) (inputs[0]);
T* to_tensor = (T*) (inputs[0]);
T* attr_mask = (T*) (inputs[1]);
T* transformer_out = (T*) (outputs[0]);
try
{
check_cuda_error(cublasSetStream(cublas_handle_, stream));
encoder_transformer_->trt_initialize(from_tensor, to_tensor, attr_mask, transformer_out, stream, cublas_handle_);
encoder_transformer_->forward();
}
catch(std::runtime_error& error)
{
std::cout << error.what() << std::endl;
}
return 0;
}
virtual size_t getWorkspaceSize(int nBatch) const override {return 0;}
const char* getPluginType() const override {return "TransformerPlugin";}
const char* getPluginVersion() const override {return "0";}
IPluginV2* clone() const override
{
return new TransformerPlugin(
hidden_dim_, head_num_, seq_len_, max_batch_size_,
d_attr_kernel_Q_,
d_attr_kernel_K_,
d_attr_kernel_V_,
d_attr_bias_Q_,
d_attr_bias_K_,
d_attr_bias_V_,
d_attr_output_kernel_,
d_attr_output_bias_,
d_attr_output_layernorm_beta_,
d_attr_output_layernorm_gamma_,
d_inter_kernel_,
d_inter_bias_,
d_output_kernel_,
d_output_bias_,
d_output_layernorm_beta_,
d_output_layernorm_gamma_
);
}
int initialize() override {return 0;}
void terminate() override {}
void destroy() override { delete this; }
void setPluginNamespace(const char* szNamespace) override {}
const char* getPluginNamespace() const override {return "";}
static void cudaMallocAndCopy(T *&dpWeight, const nvinfer1::Weights &w, int nValue)
{
assert(w.count == nValue);
check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));
check_cuda_error(cudaMemcpy(dpWeight, w.values, nValue * sizeof(T), cudaMemcpyHostToDevice));
T* data = (T*)malloc(sizeof(T) * nValue);
cudaMemcpy(data, dpWeight, sizeof(T) * nValue, cudaMemcpyDeviceToHost);
}
static void cudaMallocAndCopy(T*&dpWeight, const T *&dpWeightOld, int nValue)
{
check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));
check_cuda_error(cudaMemcpy(dpWeight, dpWeightOld, nValue * sizeof(T), cudaMemcpyDeviceToDevice));
}
private:
int hidden_dim_ = 0, head_num_ = 0, seq_len_ = 0, max_batch_size_;
T *d_attr_kernel_Q_ = NULL, *d_attr_kernel_K_ = NULL, *d_attr_kernel_V_ = NULL;
T *d_attr_bias_Q_ = NULL, *d_attr_bias_K_ = NULL, *d_attr_bias_V_ = NULL;
T *d_attr_output_kernel_ = NULL, *d_attr_output_bias_ = NULL;
T *d_attr_output_layernorm_beta_ = NULL;
T *d_attr_output_layernorm_gamma_ = NULL;
T *d_inter_kernel_ = NULL, *d_inter_bias_ = NULL;
T *d_output_kernel_ = NULL, *d_output_bias_ = NULL, *d_output_layernorm_beta_ = NULL, *d_output_layernorm_gamma_ = NULL;
cublasHandle_t cublas_handle_;
typedef BertEncoderTransformerTraits< TransformerTrtTraits<T>::OpType , cuda::OpenMultiHeadAttention> EncoderTraits_;
BertEncoderTransformer<EncoderTraits_> *encoder_transformer_;
fastertransformer::Allocator<AllocatorType::CUDA> *allocator_;
};

View file

@ -0,0 +1,148 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "fastertransformer/common.h"
#include "fastertransformer/trt_plugin/bert_transformer_plugin.h"
#include <assert.h>
#include <cstdio>
#include <cstdlib>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <chrono>
#include <iostream>
#include <NvInfer.h>
#include <map>
#include <string>
#include <vector>
template<typename T>
class TRT_Transformer
{
public:
TRT_Transformer(const int batch_size, const int seq_len, const int head_num, const int hidden_dim, const int num_layers)
:batch_size_(batch_size), seq_len_(seq_len), head_num_(head_num), hidden_dim_(hidden_dim), num_layers_(num_layers)
{
dtype_ = TransformerTrtTraits<T>::DataType;
}
~TRT_Transformer()
{
check_cuda_error(cudaFree(buffers[input_index_]));
check_cuda_error(cudaFree(buffers[mask_index_]));
check_cuda_error(cudaFree(buffers[output_index_]));
engine_->destroy();
context_->destroy();
}
nvinfer1::Weights point2weight(T* ptr, int size)
{
return nvinfer1::Weights{dtype_, ptr, (long)size};
}
void build_engine(std::vector<std::vector<T* > > &weights)
{
assert(weights.size() == num_layers_);
for(int i = 0; i < num_layers_; ++i)
assert(weights[i].size() == 16);
nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger);
assert(builder);
nvinfer1::INetworkDefinition* network = builder->createNetwork();
auto from_tensor = network->addInput(INPUT_BLOB_NAME, dtype_, nvinfer1::Dims2{seq_len_, hidden_dim_});
auto mask_tensor = network->addInput(MASK_BLOB_NAME, dtype_, nvinfer1::Dims2{seq_len_, seq_len_});
assert(input_tensor);
assert(mask_tensor);
nvinfer1::ITensor* output_tensor = nullptr;
for(int i = 0; i < num_layers_; ++i)
{
auto plugin = new TransformerPlugin<T>(
hidden_dim_, head_num_, seq_len_, batch_size_,
point2weight(weights[i][0], hidden_dim_ * hidden_dim_),
point2weight(weights[i][1], hidden_dim_ * hidden_dim_),
point2weight(weights[i][2], hidden_dim_ * hidden_dim_),
point2weight(weights[i][3], hidden_dim_),
point2weight(weights[i][4], hidden_dim_),
point2weight(weights[i][5], hidden_dim_),
point2weight(weights[i][6], hidden_dim_ * hidden_dim_),
point2weight(weights[i][7], hidden_dim_),
point2weight(weights[i][8], hidden_dim_),
point2weight(weights[i][9], hidden_dim_),
point2weight(weights[i][10], hidden_dim_ * hidden_dim_ * 4),
point2weight(weights[i][11], hidden_dim_ * 4),
point2weight(weights[i][12], hidden_dim_ * hidden_dim_ * 4),
point2weight(weights[i][13], hidden_dim_),
point2weight(weights[i][14], hidden_dim_),
point2weight(weights[i][15], hidden_dim_)
);
assert(plugin);
ITensor *inputs[] = {from_tensor, mask_tensor};
auto transformerLayer = network->addPluginV2(inputs, 2, *plugin);
from_tensor = transformerLayer->getOutput(0);
output_tensor = from_tensor;
}
output_tensor->setName(OUTPUT_BLOB_NAME);
network->markOutput(*output_tensor);
builder->setMaxBatchSize(batch_size_);
builder->setMaxWorkspaceSize(1 << 20);
builder->setFp16Mode(false);
engine_ = builder->buildCudaEngine(*network);
assert(engine_);
network->destroy();
builder->destroy();
input_index_ = engine_->getBindingIndex(INPUT_BLOB_NAME);
mask_index_ = engine_->getBindingIndex(MASK_BLOB_NAME);
output_index_ = engine_->getBindingIndex(OUTPUT_BLOB_NAME);
check_cuda_error(cudaMalloc(&buffers[input_index_], batch_size_ * seq_len_ * hidden_dim_ * sizeof(T)));
check_cuda_error(cudaMalloc(&buffers[mask_index_], batch_size_ * seq_len_ * seq_len_ * sizeof(T)));
check_cuda_error(cudaMalloc(&buffers[output_index_], batch_size_ * seq_len_ * hidden_dim_ * sizeof(T)));
context_ = engine_->createExecutionContext();
}
void do_inference(int batch_size, const T* h_from_tensor, const T* h_attr_mask, T* h_output, cudaStream_t stream)
{
cudaMemcpyAsync(buffers[input_index_], h_from_tensor, batch_size * seq_len_ * hidden_dim_ * sizeof(T),
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(buffers[mask_index_], h_attr_mask, batch_size * seq_len_ * seq_len_ * sizeof(T),
cudaMemcpyHostToDevice, stream);
context_->enqueue(batch_size_, buffers, stream, nullptr);
cudaMemcpyAsync(h_output, buffers[output_index_], batch_size * seq_len_ * hidden_dim_ * sizeof(T),
cudaMemcpyDeviceToHost, stream);
}
private:
const int batch_size_, seq_len_, head_num_, hidden_dim_, num_layers_;
nvinfer1::DataType dtype_;
int inputN_, outputN_, input_index_, mask_index_, output_index_;
nvinfer1::ICudaEngine* engine_;
nvinfer1::IExecutionContext* context_;
std::map<std::string, nvinfer1::Weights> weightMap_;
void* buffers[3];
const char* INPUT_BLOB_NAME = "input";
const char* MASK_BLOB_NAME = "mask";
const char* OUTPUT_BLOB_NAME = "prob";
};

View file

@ -0,0 +1,62 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Tools
**/
#pragma once
#include "fastertransformer/common.h"
#include <cuda_runtime.h>
namespace fastertransformer{
/**
* Pop current cuda device and set new device
* i_device - device ID to set
* o_device - device ID to pop
* ret - return code (the same as cudaError_t)
*/
inline cudaError_t get_set_device(int i_device, int* o_device = NULL){
int current_dev_id = 0;
cudaError_t err = cudaSuccess;
if (o_device != NULL) {
err = cudaGetDevice(&current_dev_id);
if (err != cudaSuccess)
return err;
if (current_dev_id == i_device){
*o_device = i_device;
}
else{
err = cudaSetDevice(i_device);
if (err != cudaSuccess) {
return err;
}
*o_device = current_dev_id;
}
}
else{
err = cudaSetDevice(i_device);
if (err != cudaSuccess) {
return err;
}
}
return cudaSuccess;
}
}

View file

@ -0,0 +1,18 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
add_subdirectory(cpp)
if(BUILD_TRT)
add_subdirectory(tensorRT)
endif()

View file

@ -0,0 +1,28 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8)
set(transformer_fp32_files
transformer_fp32.cc
)
set(transformer_fp16_files
transformer_fp16.cc
)
add_executable(transformer_fp32 ${transformer_fp32_files})
target_link_libraries(transformer_fp32 PUBLIC -lcublas -lcudart fastertransformer ${CMAKE_THREAD_LIBS_INIT})
add_executable(transformer_fp16 ${transformer_fp16_files})
target_link_libraries(transformer_fp16 PUBLIC -lcublas -lcudart fastertransformer ${CMAKE_THREAD_LIBS_INIT})

View file

@ -0,0 +1,169 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fastertransformer/faster_transformer.h"
#include <cstdio>
#include <cstdlib>
#include <cuda_profiler_api.h>
#include <iostream>
#include <sys/time.h>
#include <cuda_fp16.h>
using namespace fastertransformer;
typedef __half T;
double diffTime(timeval start, timeval end)
{
return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001;
}
void host_malloc(T** ptr, int size)
{
(*ptr) = (T*)malloc(sizeof(T) * size);
}
void device_malloc(T** ptr, int size)
{
cudaMalloc((void**)ptr, sizeof(T) * size);
}
void copy_to_device(T** d_ptr, T** h_ptr, int size)
{
cudaMemcpy((*d_ptr), (*h_ptr), sizeof(T) * size, cudaMemcpyHostToDevice);
}
int main(int argc, char* argv[])
{
struct cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if(argc != 6)
{
printf("./transformer_fp16 batch_size num_layers seq_len head_num size_per_head\n");
printf("e.g., ./transformer_fp16 1 12 128 12 64\n");
return 0;
}
printf("Device %s\n", prop.name);
int batch_size = atoi(argv[1]);
int num_layers = atoi(argv[2]);
int seq_len = atoi(argv[3]);
int head_num = atoi(argv[4]);
int size_per_head = atoi(argv[5]);
int from_seq_len = seq_len;
int to_seq_len = seq_len;
int hidden_dim = head_num * size_per_head;
T *d_from_tensor = NULL, *d_transformer_out = NULL;
T *d_attr_kernel_Q = NULL, *d_attr_kernel_K = NULL, *d_attr_kernel_V = NULL;
T *d_attr_bias_Q = NULL, *d_attr_bias_K = NULL, *d_attr_bias_V = NULL;
T *d_attr_mask = NULL, *d_attr_output_kernel = NULL, *d_attr_output_bias = NULL;
T *d_attr_output_layernorm_beta = NULL;
T *d_attr_output_layernorm_gamma = NULL;
T *d_inter_kernel = NULL, *d_inter_bias = NULL;
T *d_output_kernel = NULL, *d_output_bias = NULL, *d_output_layernorm_beta = NULL, *d_output_layernorm_gamma = NULL;
size_t free_bytes, total_bytes;
cudaMemGetInfo(&free_bytes, &total_bytes);
float free = (float)(free_bytes)/ 1024.0 / 1024.0 / 1024.0;
float total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0;
printf("before allocate free %.2f GB total %.2f GB\n", free, total);
device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim);
device_malloc(&d_transformer_out, batch_size * seq_len * hidden_dim);
device_malloc(&d_attr_kernel_Q, hidden_dim * hidden_dim);
device_malloc(&d_attr_kernel_K, hidden_dim * hidden_dim);
device_malloc(&d_attr_kernel_V, hidden_dim * hidden_dim);
device_malloc(&d_attr_bias_Q, hidden_dim);
device_malloc(&d_attr_bias_K, hidden_dim);
device_malloc(&d_attr_bias_V, hidden_dim);
device_malloc(&d_attr_mask, batch_size * seq_len * seq_len);
device_malloc(&d_attr_output_kernel, hidden_dim * hidden_dim);
device_malloc(&d_attr_output_bias, hidden_dim);
device_malloc(&d_attr_output_layernorm_beta, hidden_dim);
device_malloc(&d_attr_output_layernorm_gamma, hidden_dim);
device_malloc(&d_inter_kernel, hidden_dim * hidden_dim * 4);
device_malloc(&d_inter_bias, hidden_dim * 4);
device_malloc(&d_output_kernel, hidden_dim * hidden_dim * 4);
device_malloc(&d_output_bias, hidden_dim);
device_malloc(&d_output_layernorm_beta, hidden_dim);
device_malloc(&d_output_layernorm_gamma, hidden_dim);
cudaMemGetInfo(&free_bytes, &total_bytes);
free = (float)(free_bytes)/ 1024.0 / 1024.0 / 1024.0;
total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0;
printf("After allocate free %.2f GB used %.2f GB total %.2f GB\n", free, total - free, total);
cublasHandle_t cublasHandle;
cublasCreate(&cublasHandle);
cudaStream_t stream;
cudaStreamCreate(&stream);
cublasSetStream(cublasHandle, stream);
typedef BertEncoderTransformerTraits<OperationType::HALF, cuda::OpenMultiHeadAttention> EncoderTraits_;
fastertransformer::Allocator<AllocatorType::CUDA> allocator(0);
EncoderInitParam<__half> encoder_param; //init param here
encoder_param.from_tensor = d_from_tensor;
encoder_param.to_tensor = d_from_tensor;
encoder_param.attr_kernel_Q = d_attr_kernel_Q;
encoder_param.attr_kernel_K = d_attr_kernel_K;
encoder_param.attr_kernel_V = d_attr_kernel_V;
encoder_param.attr_bias_Q = d_attr_bias_Q;
encoder_param.attr_bias_K = d_attr_bias_K;
encoder_param.attr_bias_V = d_attr_bias_V;
encoder_param.attr_mask = d_attr_mask;
encoder_param.attr_output_kernel = d_attr_output_kernel;
encoder_param.attr_output_bias = d_attr_output_bias;
encoder_param.attr_output_layernorm_beta = d_attr_output_layernorm_beta;
encoder_param.attr_output_layernorm_gamma = d_attr_output_layernorm_gamma;
encoder_param.inter_kernel = d_inter_kernel;
encoder_param.inter_bias = d_inter_bias;
encoder_param.output_kernel = d_output_kernel;
encoder_param.output_bias = d_output_bias;
encoder_param.output_layernorm_beta = d_output_layernorm_beta;
encoder_param.output_layernorm_gamma = d_output_layernorm_gamma;
encoder_param.transformer_out = d_transformer_out;
encoder_param.cublas_handle = cublasHandle;
encoder_param.stream = stream;
BertEncoderTransformer<EncoderTraits_> *encoder_transformer_ = new
BertEncoderTransformer<EncoderTraits_>(allocator, batch_size, from_seq_len, to_seq_len, head_num, size_per_head);
encoder_transformer_->initialize(encoder_param);
int ite = 200;
//warp up
for(int i = 0; i < ite; ++i)
encoder_transformer_->forward();
struct timeval ss, ee;
cudaDeviceSynchronize();
gettimeofday(&ss, NULL);
for(int i = 0; i < ite; ++i)
{
for(int j = 0; j < num_layers; ++j)
encoder_transformer_->forward();
}
cudaDeviceSynchronize();
gettimeofday(&ee, NULL);
printf("[batch_size %d seq_len %d %d transformer layers] costs %.2f ms\n", batch_size, seq_len, num_layers,
diffTime(ss, ee) / ite);
delete encoder_transformer_;
return 0;
}

View file

@ -0,0 +1,168 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fastertransformer/faster_transformer.h"
#include <cstdio>
#include <cstdlib>
#include <cuda_profiler_api.h>
#include <iostream>
#include <sys/time.h>
using namespace fastertransformer;
double diffTime(timeval start, timeval end)
{
return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001;
}
void host_malloc(float** ptr, int size)
{
(*ptr) = (float*)malloc(sizeof(float) * size);
}
void device_malloc(float** ptr, int size)
{
cudaMalloc((void**)ptr, sizeof(float) * size);
}
void copy_to_device(float** d_ptr, float** h_ptr, int size)
{
cudaMemcpy((*d_ptr), (*h_ptr), sizeof(float) * size, cudaMemcpyHostToDevice);
}
int main(int argc, char* argv[])
{
if(argc != 6)
{
printf("./transformer_fp32 batch_size num_layers seq_len head_num size_per_head\n");
printf("e.g., ./transformer_fp32 1 12 128 12 64\n");
return 0;
}
struct cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
printf("Device %s\n", prop.name);
int batch_size = atoi(argv[1]);
int num_layers = atoi(argv[2]);
int seq_len = atoi(argv[3]);
int head_num = atoi(argv[4]);
int size_per_head = atoi(argv[5]);
int from_seq_len = seq_len;
int to_seq_len = seq_len;
int hidden_dim = head_num * size_per_head;
float *d_from_tensor = NULL, *d_transformer_out = NULL;
float *d_attr_kernel_Q = NULL, *d_attr_kernel_K = NULL, *d_attr_kernel_V = NULL;
float *d_attr_bias_Q = NULL, *d_attr_bias_K = NULL, *d_attr_bias_V = NULL;
float *d_attr_mask = NULL, *d_attr_output_kernel = NULL, *d_attr_output_bias = NULL;
float *d_attr_output_layernorm_beta = NULL;
float *d_attr_output_layernorm_gamma = NULL;
float *d_inter_kernel = NULL, *d_inter_bias = NULL;
float *d_output_kernel = NULL, *d_output_bias = NULL, *d_output_layernorm_beta = NULL, *d_output_layernorm_gamma = NULL;
size_t free_bytes, total_bytes;
cudaMemGetInfo(&free_bytes, &total_bytes);
float free = (float)(free_bytes)/ 1024.0 / 1024.0 / 1024.0;
float total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0;
printf("before allocate free %.2f GB total %.2f GB\n", free, total);
device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim);
device_malloc(&d_transformer_out, batch_size * seq_len * hidden_dim);
device_malloc(&d_attr_kernel_Q, hidden_dim * hidden_dim);
device_malloc(&d_attr_kernel_K, hidden_dim * hidden_dim);
device_malloc(&d_attr_kernel_V, hidden_dim * hidden_dim);
device_malloc(&d_attr_bias_Q, hidden_dim);
device_malloc(&d_attr_bias_K, hidden_dim);
device_malloc(&d_attr_bias_V, hidden_dim);
device_malloc(&d_attr_mask, batch_size * seq_len * seq_len);
device_malloc(&d_attr_output_kernel, hidden_dim * hidden_dim);
device_malloc(&d_attr_output_bias, hidden_dim);
device_malloc(&d_attr_output_layernorm_beta, hidden_dim);
device_malloc(&d_attr_output_layernorm_gamma, hidden_dim);
device_malloc(&d_inter_kernel, hidden_dim * hidden_dim * 4);
device_malloc(&d_inter_bias, hidden_dim * 4);
device_malloc(&d_output_kernel, hidden_dim * hidden_dim * 4);
device_malloc(&d_output_bias, hidden_dim);
device_malloc(&d_output_layernorm_beta, hidden_dim);
device_malloc(&d_output_layernorm_gamma, hidden_dim);
cudaMemGetInfo(&free_bytes, &total_bytes);
free = (float)(free_bytes)/ 1024.0 / 1024.0 / 1024.0;
total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0;
printf("After allocate free %.2f GB used %.2f GB total %.2f GB\n", free, total - free, total);
cublasHandle_t cublasHandle;
cublasCreate(&cublasHandle);
cudaStream_t stream;
cudaStreamCreate(&stream);
cublasSetStream(cublasHandle, stream);
typedef BertEncoderTransformerTraits<OperationType::FP32, cuda::OpenMultiHeadAttention> EncoderTraits_;
fastertransformer::Allocator<AllocatorType::CUDA> allocator(0);
EncoderInitParam<float> encoder_param; //init param here
encoder_param.from_tensor = d_from_tensor;
encoder_param.to_tensor = d_from_tensor;
encoder_param.attr_kernel_Q = d_attr_kernel_Q;
encoder_param.attr_kernel_K = d_attr_kernel_K;
encoder_param.attr_kernel_V = d_attr_kernel_V;
encoder_param.attr_bias_Q = d_attr_bias_Q;
encoder_param.attr_bias_K = d_attr_bias_K;
encoder_param.attr_bias_V = d_attr_bias_V;
encoder_param.attr_mask = d_attr_mask;
encoder_param.attr_output_kernel = d_attr_output_kernel;
encoder_param.attr_output_bias = d_attr_output_bias;
encoder_param.attr_output_layernorm_beta = d_attr_output_layernorm_beta;
encoder_param.attr_output_layernorm_gamma = d_attr_output_layernorm_gamma;
encoder_param.inter_kernel = d_inter_kernel;
encoder_param.inter_bias = d_inter_bias;
encoder_param.output_kernel = d_output_kernel;
encoder_param.output_bias = d_output_bias;
encoder_param.output_layernorm_beta = d_output_layernorm_beta;
encoder_param.output_layernorm_gamma = d_output_layernorm_gamma;
encoder_param.transformer_out = d_transformer_out;
encoder_param.cublas_handle = cublasHandle;
encoder_param.stream = stream;
BertEncoderTransformer<EncoderTraits_> *encoder_transformer_ = new
BertEncoderTransformer<EncoderTraits_>(allocator, batch_size, from_seq_len, to_seq_len, head_num, size_per_head);
encoder_transformer_->initialize(encoder_param);
int ite = 200;
//warp up
for(int i = 0; i < ite; ++i)
encoder_transformer_->forward();
struct timeval ss, ee;
cudaDeviceSynchronize();
gettimeofday(&ss, NULL);
for(int i = 0; i < ite; ++i)
{
for(int j = 0; j < num_layers; ++j)
encoder_transformer_->forward();
}
cudaDeviceSynchronize();
gettimeofday(&ee, NULL);
printf("[batch_size %d seq_len %d %d transformer layers] costs %.2f ms\n", batch_size, seq_len, num_layers,
diffTime(ss, ee) / ite);
delete encoder_transformer_;
return 0;
}

View file

@ -0,0 +1,70 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# python ckpt_type_convert.py --init_checkpoint=mrpc_output/model.ckpt-343 --fp16_checkpoint=mrpc_output/fp16_model.ckpt
import tensorflow as tf
import numpy as np
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder
def checkpoint_dtype_cast(in_checkpoint_file, out_checkpoint_file):
var_list = checkpoint_utils.list_variables(tf.flags.FLAGS.init_checkpoint)
def init_graph():
for name, shape in var_list:
var = checkpoint_utils.load_variable(tf.flags.FLAGS.init_checkpoint, name)
recon_dtype = tf.float16 if var.dtype == np.float32 else var.dtype
tf.get_variable(name, shape=shape, dtype=recon_dtype)
init_graph()
saver = tf.train.Saver(builder=CastFromFloat32SaverBuilder())
with tf.Session() as sess:
saver.restore(sess, in_checkpoint_file)
saver.save(sess, 'tmp.ckpt')
tf.reset_default_graph()
init_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'tmp.ckpt')
saver.save(sess, out_checkpoint_file)
class CastFromFloat32SaverBuilder(BaseSaverBuilder):
# Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
def bulk_restore(self, filename_tensor, saveables, preferred_shard,
restore_sequentially):
restore_specs = []
for saveable in saveables:
for spec in saveable.specs:
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
names, slices, dtypes = zip(*restore_specs)
restore_dtypes = [tf.float32 if dtype.base_dtype==tf.float16 else dtype for dtype in dtypes]
# print info
for i in range(len(restore_specs)):
print(names[i], 'from', restore_dtypes[i], 'to', dtypes[i].base_dtype)
with tf.device("cpu:0"):
restored = io_ops.restore_v2(
filename_tensor, names, slices, restore_dtypes)
return [tf.cast(r, dt.base_dtype) for r, dt in zip(restored, dtypes)]
if __name__ == '__main__':
tf.flags.DEFINE_string("fp16_checkpoint", None, "fp16 checkpoint file")
tf.flags.DEFINE_string("init_checkpoint", None, "initial checkpoint file")
checkpoint_dtype_cast(tf.flags.FLAGS.init_checkpoint, tf.flags.FLAGS.fp16_checkpoint)

View file

@ -0,0 +1,286 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
import tensorflow as tf
import os
import sys
from my_modeling import *
build_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../build/lib')
transformer_op_module = tf.load_op_library(
os.path.join(build_path, 'libtf_fastertransformer.so'))
def file_based_input_fn_builder_drop(input_file, seq_length, is_training,
drop_remainder):
""" Re-implementation of file_based_input_fn_builder function from modeling.py from Google's BERT repository https://github.com/google-research/bert
with drop_remainder=True.
"""
name_to_features = {
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.FixedLenFeature([], tf.int64),
"is_real_example": tf.FixedLenFeature([], tf.int64),
}
def _decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.to_int32(t)
example[name] = t
return example
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
if is_training:
d = d.repeat()
d = d.shuffle(buffer_size=100)
# FASTINFER: drop remainder always
d = d.apply(
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
drop_remainder=True))
return d
return input_fn
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
labels, num_labels, use_one_hot_embeddings):
"""Creates a classification model."""
model = BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
output_layer = model.get_pooled_output()
hidden_size = output_layer.shape[-1].value
output_weights = tf.get_variable(
"output_weights", [num_labels, hidden_size],
dtype=tf.flags.FLAGS.floatx,
initializer=tf.truncated_normal_initializer(stddev=0.02))
output_bias = tf.get_variable(
"output_bias", [num_labels],
dtype=tf.flags.FLAGS.floatx,
initializer=tf.zeros_initializer())
with tf.variable_scope("loss"):
if is_training:
# I.e., 0.1 dropout
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
probabilities = tf.nn.softmax(logits, axis=-1)
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.flags.FLAGS.floatx)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
return (loss, per_example_loss, logits, probabilities)
def get_available_gpus():
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == 'GPU']
def fast_transformer_model_trans(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
do_return_all_layers=False):
""" Re-implementation of transformer_model function from modeling.py from Google's BERT repository https://github.com/google-research/bert
using FasterTransformer Tensorflow op.
Multi-headed, multi-layer Transformer from "Attention is All You Need".
This is almost an exact implementation of the original Transformer encoder.
See the original paper:
https://arxiv.org/abs/1706.03762
Also see:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
Args:
input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
seq_length], with 1 for positions that can be attended to and 0 in
positions that should not be.
hidden_size: int. Hidden size of the Transformer.
num_hidden_layers: int. Number of layers (blocks) in the Transformer.
num_attention_heads: int. Number of attention heads in the Transformer.
intermediate_size: int. The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: function. The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: float. Dropout probability for the hidden layers.
attention_probs_dropout_prob: float. Dropout probability of the attention
probabilities.
initializer_range: float. Range of the initializer (stddev of truncated
normal).
do_return_all_layers: Whether to also return all layers or just the final
layer.
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size], the final
hidden layer of the Transformer.
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
# The Transformer performs sum residuals on all layers so the input needs
# to be the same as the hidden size.
if input_width != hidden_size:
raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
(input_width, hidden_size))
# We keep the representation as a 2D tensor to avoid re-shaping it back and
# forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
# the GPU/CPU but may not be free on the TPU, so we want to minimize them to
# help the optimizer.
prev_output = reshape_to_matrix(input_tensor)
all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
layer_input = prev_output
with tf.variable_scope("attention"):
attention_heads = []
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_heads.append(attention_head)
attention_output = None
if len(attention_heads) == 1:
attention_output = attention_heads[0]
else:
# In the case where we have other sequences, we just concatenate
# them to the self-attention head before the projection.
attention_output = tf.concat(attention_heads, axis=-1)
# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(
attention_output, hidden_dropout_prob)
attention_output = layer_norm(
attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
# FASTINFER: fast transformer encoder inference
trainable_vars = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name)
layer_output = transformer_op_module.bert_transformer(
layer_input,
layer_input,
trainable_vars[0], trainable_vars[2], trainable_vars[4], trainable_vars[1], trainable_vars[3], trainable_vars[5],
attention_mask,
trainable_vars[6], trainable_vars[7], trainable_vars[8], trainable_vars[9], trainable_vars[10], trainable_vars[11],
trainable_vars[12], trainable_vars[13], trainable_vars[14], trainable_vars[15],
batch_size=batch_size, from_seq_len=seq_length, to_seq_len=seq_length, head_num=num_attention_heads, size_per_head=attention_head_size)
prev_output = layer_output
all_layer_outputs.append(layer_output)
if do_return_all_layers:
final_outputs = []
for layer_output in all_layer_outputs:
final_output = reshape_from_matrix(layer_output, input_shape)
final_outputs.append(final_output)
return final_outputs
else:
final_output = reshape_from_matrix(prev_output, input_shape)
return final_output

View file

@ -0,0 +1,991 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is mostly the same as bert/modeling.py from Google's BERT repository https://github.com/google-research/bert
# with configurable float types by setting tf.flags.FLAGS.floatx
"""The main BERT model and related functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import re
import numpy as np
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class BertModel(object):
"""BERT model ("Bidirectional Encoder Representations from Transformers").
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = modeling.BertModel(config=config, is_training=True,
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
label_embeddings = tf.get_variable(...)
pooled_output = model.get_pooled_output()
logits = tf.matmul(pooled_output, label_embeddings)
...
```
"""
def __init__(self,
config,
is_training,
input_ids,
input_mask=None,
token_type_ids=None,
use_one_hot_embeddings=False,
scope=None):
"""Constructor for BertModel.
Args:
config: `BertConfig` instance.
is_training: bool. true for training model, false for eval model. Controls
whether dropout will be applied.
input_ids: int32 Tensor of shape [batch_size, seq_length].
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
embeddings or tf.embedding_lookup() for the word embeddings.
scope: (optional) variable scope. Defaults to "bert".
Raises:
ValueError: The config is invalid or one of the input tensor shapes
is invalid.
"""
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
input_shape = get_shape_list(input_ids, expected_rank=2)
batch_size = input_shape[0]
seq_length = input_shape[1]
if input_mask is None:
input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
input_ids=input_ids,
vocab_size=config.vocab_size,
embedding_size=config.hidden_size,
initializer_range=config.initializer_range,
word_embedding_name="word_embeddings",
use_one_hot_embeddings=use_one_hot_embeddings)
# Add positional embeddings and token type embeddings, then layer
# normalize and perform dropout.
self.embedding_output = embedding_postprocessor(
input_tensor=self.embedding_output,
use_token_type=True,
token_type_ids=token_type_ids,
token_type_vocab_size=config.type_vocab_size,
token_type_embedding_name="token_type_embeddings",
use_position_embeddings=True,
position_embedding_name="position_embeddings",
initializer_range=config.initializer_range,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
with tf.variable_scope("encoder"):
# This converts a 2D mask of shape [batch_size, seq_length] to a 3D
# mask of shape [batch_size, seq_length, seq_length] which is used
# for the attention scores.
attention_mask = create_attention_mask_from_input_mask(
input_ids, input_mask)
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model(
input_tensor=self.embedding_output,
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_act_fn=get_activation(config.hidden_act),
hidden_dropout_prob=config.hidden_dropout_prob,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
initializer_range=config.initializer_range,
do_return_all_layers=True)
self.sequence_output = self.all_encoder_layers[-1]
# The "pooler" converts the encoded sequence tensor of shape
# [batch_size, seq_length, hidden_size] to a tensor of shape
# [batch_size, hidden_size]. This is necessary for segment-level
# (or segment-pair-level) classification tasks where we need a fixed
# dimensional representation of the segment.
with tf.variable_scope("pooler"):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token. We assume that this has been pre-trained
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
kernel_initializer=create_initializer(config.initializer_range))
def get_pooled_output(self):
return self.pooled_output
def get_sequence_output(self):
"""Gets final hidden layer of encoder.
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
to the final hidden of the transformer encoder.
"""
return self.sequence_output
def get_all_encoder_layers(self):
return self.all_encoder_layers
def get_embedding_output(self):
"""Gets output of the embedding lookup (i.e., input to the transformer).
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
to the output of the embedding layer, after summing the word
embeddings with the positional embeddings and the token type embeddings,
then performing layer normalization. This is the input to the transformer.
"""
return self.embedding_output
def get_embedding_table(self):
return self.embedding_table
def gelu(x):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def get_activation(activation_string):
"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
Args:
activation_string: String name of the activation function.
Returns:
A Python function corresponding to the activation function. If
`activation_string` is None, empty, or "linear", this will return None.
If `activation_string` is not a string, it will return `activation_string`.
Raises:
ValueError: The `activation_string` does not correspond to a known
activation.
"""
# We assume that anything that"s not a string is already an activation
# function, so we just return it.
if not isinstance(activation_string, six.string_types):
return activation_string
if not activation_string:
return None
act = activation_string.lower()
if act == "linear":
return None
elif act == "relu":
return tf.nn.relu
elif act == "gelu":
return gelu
elif act == "tanh":
return tf.tanh
else:
raise ValueError("Unsupported activation: %s" % act)
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
"""Compute the union of the current variables and checkpoint variables."""
assignment_map = {}
initialized_variable_names = {}
name_to_variable = collections.OrderedDict()
for var in tvars:
name = var.name
m = re.match("^(.*):\\d+$", name)
if m is not None:
name = m.group(1)
name_to_variable[name] = var
init_vars = tf.train.list_variables(init_checkpoint)
assignment_map = collections.OrderedDict()
for x in init_vars:
(name, var) = (x[0], x[1])
if name not in name_to_variable:
continue
assignment_map[name] = name
initialized_variable_names[name] = 1
initialized_variable_names[name + ":0"] = 1
return (assignment_map, initialized_variable_names)
def dropout(input_tensor, dropout_prob):
"""Perform dropout.
Args:
input_tensor: float Tensor.
dropout_prob: Python float. The probability of dropping out a value (NOT of
*keeping* a dimension as in `tf.nn.dropout`).
Returns:
A version of `input_tensor` with dropout applied.
"""
if dropout_prob is None or dropout_prob == 0.0:
return input_tensor
output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
return output
def layer_norm(input_tensor, name=None):
"""Run layer normalization on the last dimension of the tensor."""
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
"""Runs layer normalization followed by dropout."""
output_tensor = layer_norm(input_tensor, name)
output_tensor = dropout(output_tensor, dropout_prob)
return output_tensor
def create_initializer(initializer_range=0.02):
"""Creates a `truncated_normal_initializer` with the given range."""
return tf.truncated_normal_initializer(stddev=initializer_range, dtype=tf.flags.FLAGS.floatx)
def embedding_lookup(input_ids,
vocab_size,
embedding_size=128,
initializer_range=0.02,
word_embedding_name="word_embeddings",
use_one_hot_embeddings=False):
"""Looks up words embeddings for id tensor.
Args:
input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
ids.
vocab_size: int. Size of the embedding vocabulary.
embedding_size: int. Width of the word embeddings.
initializer_range: float. Embedding initialization range.
word_embedding_name: string. Name of the embedding table.
use_one_hot_embeddings: bool. If True, use one-hot method for word
embeddings. If False, use `tf.gather()`.
Returns:
float Tensor of shape [batch_size, seq_length, embedding_size].
"""
# This function assumes that the input is of shape [batch_size, seq_length,
# num_inputs].
#
# If the input is a 2D tensor of shape [batch_size, seq_length], we
# reshape to [batch_size, seq_length, 1].
if input_ids.shape.ndims == 2:
input_ids = tf.expand_dims(input_ids, axis=[-1])
embedding_table = tf.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
dtype=tf.flags.FLAGS.floatx,
initializer=create_initializer(initializer_range))
flat_input_ids = tf.reshape(input_ids, [-1])
if use_one_hot_embeddings:
one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
output = tf.matmul(one_hot_input_ids, embedding_table)
else:
output = tf.gather(embedding_table, flat_input_ids)
input_shape = get_shape_list(input_ids)
output = tf.reshape(output,
input_shape[0:-1] + [input_shape[-1] * embedding_size])
return (output, embedding_table)
def embedding_postprocessor(input_tensor,
use_token_type=False,
token_type_ids=None,
token_type_vocab_size=16,
token_type_embedding_name="token_type_embeddings",
use_position_embeddings=True,
position_embedding_name="position_embeddings",
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
"""Performs various post-processing on a word embedding tensor.
Args:
input_tensor: float Tensor of shape [batch_size, seq_length,
embedding_size].
use_token_type: bool. Whether to add embeddings for `token_type_ids`.
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
Must be specified if `use_token_type` is True.
token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
token_type_embedding_name: string. The name of the embedding table variable
for token type ids.
use_position_embeddings: bool. Whether to add position embeddings for the
position of each token in the sequence.
position_embedding_name: string. The name of the embedding table variable
for positional embeddings.
initializer_range: float. Range of the weight initialization.
max_position_embeddings: int. Maximum sequence length that might ever be
used with this model. This can be longer than the sequence length of
input_tensor, but cannot be shorter.
dropout_prob: float. Dropout probability applied to the final output tensor.
Returns:
float tensor with same shape as `input_tensor`.
Raises:
ValueError: One of the tensor shapes or input values is invalid.
"""
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
width = input_shape[2]
output = input_tensor
if use_token_type:
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
token_type_table = tf.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
dtype=tf.flags.FLAGS.floatx,
initializer=create_initializer(initializer_range))
# This vocab will be small so we always do one-hot here, since it is always
# faster for a small vocabulary.
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size, dtype=tf.flags.FLAGS.floatx)
token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
output += token_type_embeddings
if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
dtype=tf.flags.FLAGS.floatx,
initializer=create_initializer(initializer_range))
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
# tasks that do not have long sequences.
#
# So `full_position_embeddings` is effectively an embedding table
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
position_embeddings = tf.slice(full_position_embeddings, [0, 0],
[seq_length, -1])
num_dims = len(output.shape.as_list())
# Only the last two dimensions are relevant (`seq_length` and `width`), so
# we broadcast among the first dimensions, which is typically just
# the batch size.
position_broadcast_shape = []
for _ in range(num_dims - 2):
position_broadcast_shape.append(1)
position_broadcast_shape.extend([seq_length, width])
position_embeddings = tf.reshape(position_embeddings,
position_broadcast_shape)
output += position_embeddings
output = layer_norm_and_dropout(output, dropout_prob)
return output
def create_attention_mask_from_input_mask(from_tensor, to_mask):
"""Create 3D attention mask from a 2D tensor mask.
Args:
from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = get_shape_list(to_mask, expected_rank=2)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.flags.FLAGS.floatx)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=tf.flags.FLAGS.floatx)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
def attention_layer(from_tensor,
to_tensor,
attention_mask=None,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
do_return_2d_tensor=False,
batch_size=None,
from_seq_length=None,
to_seq_length=None):
"""Performs multi-headed attention from `from_tensor` to `to_tensor`.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `from_tensor` and `to_tensor` are the same, then
this is self-attention. Each timestep in `from_tensor` attends to the
corresponding sequence in `to_tensor`, and returns a fixed-with vector.
This function first projects `from_tensor` into a "query" tensor and
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
of tensors of length `num_attention_heads`, where each tensor is of shape
[batch_size, seq_length, size_per_head].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor and returned.
In practice, the multi-headed attention are done with transposes and
reshapes rather than actual separate tensors.
Args:
from_tensor: float Tensor of shape [batch_size, from_seq_length,
from_width].
to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
attention_mask: (optional) int32 Tensor of shape [batch_size,
from_seq_length, to_seq_length]. The values should be 1 or 0. The
attention scores will effectively be set to -infinity for any positions in
the mask that are 0, and will be unchanged for positions that are 1.
num_attention_heads: int. Number of attention heads.
size_per_head: int. Size of each attention head.
query_act: (optional) Activation function for the query transform.
key_act: (optional) Activation function for the key transform.
value_act: (optional) Activation function for the value transform.
attention_probs_dropout_prob: (optional) float. Dropout probability of the
attention probabilities.
initializer_range: float. Range of the weight initializer.
do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
* from_seq_length, num_attention_heads * size_per_head]. If False, the
output will be of shape [batch_size, from_seq_length, num_attention_heads
* size_per_head].
batch_size: (Optional) int. If the input is 2D, this might be the batch size
of the 3D version of the `from_tensor` and `to_tensor`.
from_seq_length: (Optional) If the input is 2D, this might be the seq length
of the 3D version of the `from_tensor`.
to_seq_length: (Optional) If the input is 2D, this might be the seq length
of the 3D version of the `to_tensor`.
Returns:
float Tensor of shape [batch_size, from_seq_length,
num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
true, this will be of shape [batch_size * from_seq_length,
num_attention_heads * size_per_head]).
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
seq_length, width):
output_tensor = tf.reshape(
input_tensor, [batch_size, seq_length, num_attention_heads, width])
output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
return output_tensor
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
if len(from_shape) != len(to_shape):
raise ValueError(
"The rank of `from_tensor` must match the rank of `to_tensor`.")
if len(from_shape) == 3:
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_seq_length = to_shape[1]
elif len(from_shape) == 2:
if (batch_size is None or from_seq_length is None or to_seq_length is None):
raise ValueError(
"When passing in rank 2 tensors to attention_layer, the values "
"for `batch_size`, `from_seq_length`, and `to_seq_length` "
"must all be specified.")
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
from_tensor_2d = reshape_to_matrix(from_tensor)
to_tensor_2d = reshape_to_matrix(to_tensor)
# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
kernel_initializer=create_initializer(initializer_range))
# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
kernel_initializer=create_initializer(initializer_range))
# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
kernel_initializer=create_initializer(initializer_range))
# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
num_attention_heads, from_seq_length,
size_per_head)
# `key_layer` = [B, N, T, H]
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
to_seq_length, size_per_head)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
# `attention_scores` = [B, N, F, T]
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(size_per_head))
if attention_mask is not None:
# `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1])
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(attention_mask, tf.flags.FLAGS.floatx)) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_scores += adder
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = tf.nn.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
# `value_layer` = [B, T, N, H]
value_layer = tf.reshape(
value_layer,
[batch_size, to_seq_length, num_attention_heads, size_per_head])
# `value_layer` = [B, N, T, H]
value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
# `context_layer` = [B, N, F, H]
context_layer = tf.matmul(attention_probs, value_layer)
# `context_layer` = [B, F, N, H]
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
if do_return_2d_tensor:
# `context_layer` = [B*F, N*H]
context_layer = tf.reshape(
context_layer,
[batch_size * from_seq_length, num_attention_heads * size_per_head])
else:
# `context_layer` = [B, F, N*H]
context_layer = tf.reshape(
context_layer,
[batch_size, from_seq_length, num_attention_heads * size_per_head])
return context_layer
def transformer_model(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
do_return_all_layers=False):
"""Multi-headed, multi-layer Transformer from "Attention is All You Need".
This is almost an exact implementation of the original Transformer encoder.
See the original paper:
https://arxiv.org/abs/1706.03762
Also see:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
Args:
input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
seq_length], with 1 for positions that can be attended to and 0 in
positions that should not be.
hidden_size: int. Hidden size of the Transformer.
num_hidden_layers: int. Number of layers (blocks) in the Transformer.
num_attention_heads: int. Number of attention heads in the Transformer.
intermediate_size: int. The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: function. The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: float. Dropout probability for the hidden layers.
attention_probs_dropout_prob: float. Dropout probability of the attention
probabilities.
initializer_range: float. Range of the initializer (stddev of truncated
normal).
do_return_all_layers: Whether to also return all layers or just the final
layer.
Returns:
float Tensor of shape [batch_size, seq_length, hidden_size], the final
hidden layer of the Transformer.
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
# The Transformer performs sum residuals on all layers so the input needs
# to be the same as the hidden size.
if input_width != hidden_size:
raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
(input_width, hidden_size))
# We keep the representation as a 2D tensor to avoid re-shaping it back and
# forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
# the GPU/CPU but may not be free on the TPU, so we want to minimize them to
# help the optimizer.
prev_output = reshape_to_matrix(input_tensor)
all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
layer_input = prev_output
with tf.variable_scope("attention"):
attention_heads = []
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_heads.append(attention_head)
attention_output = None
if len(attention_heads) == 1:
attention_output = attention_heads[0]
else:
# In the case where we have other sequences, we just concatenate
# them to the self-attention head before the projection.
attention_output = tf.concat(attention_heads, axis=-1)
# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
prev_output = layer_output
all_layer_outputs.append(layer_output)
if do_return_all_layers:
final_outputs = []
for layer_output in all_layer_outputs:
final_output = reshape_from_matrix(layer_output, input_shape)
final_outputs.append(final_output)
return final_outputs
else:
final_output = reshape_from_matrix(prev_output, input_shape)
return final_output
def get_shape_list(tensor, expected_rank=None, name=None):
"""Returns a list of the shape of tensor, preferring static dimensions.
Args:
tensor: A tf.Tensor object to find the shape of.
expected_rank: (optional) int. The expected rank of `tensor`. If this is
specified and the `tensor` has a different rank, and exception will be
thrown.
name: Optional name of the tensor for the error message.
Returns:
A list of dimensions of the shape of tensor. All static dimensions will
be returned as python integers, and dynamic dimensions will be returned
as tf.Tensor scalars.
"""
if name is None:
name = tensor.name
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def reshape_to_matrix(input_tensor):
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
ndims = input_tensor.shape.ndims
if ndims < 2:
raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
(input_tensor.shape))
if ndims == 2:
return input_tensor
width = input_tensor.shape[-1]
output_tensor = tf.reshape(input_tensor, [-1, width])
return output_tensor
def reshape_from_matrix(output_tensor, orig_shape_list):
"""Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
if len(orig_shape_list) == 2:
return output_tensor
output_shape = get_shape_list(output_tensor)
orig_dims = orig_shape_list[0:-1]
width = output_shape[-1]
return tf.reshape(output_tensor, orig_dims + [width])
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
if name is None:
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
scope_name = tf.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))

View file

@ -0,0 +1,196 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# export GLUE_DIR=/home/dongluw/Desktop/data/glue_data
# export BERT_BASE_DIR=/home/dongluw/Desktop/data/uncased_L-12_H-768_A-12
# python profile_bert_inference.py --task_name=MRPC --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --predict_batch_size=8 --max_seq_length=128 --output_dir=mrpc_output --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt --tf_profile=true --profiling_output_file=time_elapsed --xla=false --floatx=float32
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import numpy as np
import fast_infer_util as fiu
import profile_util
import tensorflow as tf
import os
from tensorflow.python.client import timeline
import contextlib
import time
from tensorflow.python.client import device_lib
import my_modeling
bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert')
sys.path.insert(0, bert_submodule)
import tokenization
import run_classifier as rc
flags = tf.flags
FLAGS = flags.FLAGS
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids):
"""Creates a classification model."""
model = my_modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=False)
seq_output = model.get_sequence_output()
return seq_output
def model_fn_builder(bert_config):
def model_fn(features):
# print features
tf.logging.info("*** Features ***")
for name in sorted(features.keys()):
tf.logging.info(" name = %s, shape = %s" %
(name, features[name].shape))
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
label_ids = features["label_ids"]
fetches = create_model(
bert_config, False, input_ids, input_mask, segment_ids)
# # fetch mrpc logits for prediction
# num_labels = 2 # for mrpc
# _, _, fetches, _ = fiu.create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,
# num_labels, False)
return fetches
return model_fn
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
num_iter = 20
jit_xla = tf.OptimizerOptions.ON_1 if FLAGS.xla else 0
processors = {
"cola": rc.ColaProcessor,
"mnli": rc.MnliProcessor,
"mrpc": rc.MrpcProcessor,
"xnli": rc.XnliProcessor,
}
# sanity check
tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
FLAGS.init_checkpoint)
bert_config = my_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(FLAGS.max_seq_length, bert_config.max_position_embeddings))
tf.gfile.MakeDirs(FLAGS.output_dir)
task_name = FLAGS.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
# prepare data
processor = processors[task_name]()
label_list = processor.get_labels()
predict_examples = processor.get_test_examples(FLAGS.data_dir)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
rc.file_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file)
# get model function and input function
# drop_remainder option should be turned on for fast transformer inference
drop_remainder = True
predict_input_fn = rc.file_based_input_fn_builder(
input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=drop_remainder)
def graph_fn():
model_fn = model_fn_builder(bert_config=bert_config)
dataset = predict_input_fn({'batch_size': FLAGS.predict_batch_size})
next_item = dataset.make_one_shot_iterator().get_next()
output_var = model_fn(next_item)
return output_var
if FLAGS.tf_profile:
tf.logging.info("***** Running tensorflow transformer*****")
p1 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_origin'))
t1, r1 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, p1, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
my_modeling.transformer_model = fiu.fast_transformer_model_trans
tf.logging.info("***** Running fast transformer*****")
p2 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_fastinfer'))
t2, r2 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, p2, init_checkpoint=FLAGS.init_checkpoint)
else:
tf.logging.info("***** Running tensorflow transformer*****")
t1, r1 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
my_modeling.transformer_model = fiu.fast_transformer_model_trans
tf.logging.info("***** Running fast transformer*****")
t2, r2 = profile_util.run_profile(
graph_fn, jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
print('average time (seconds) elasped original tensorflow:', t1)
print('average time (seconds) elasped fast transformer:', t2)
if len(r1) + len(r2) > 0:
check_res = np.asarray([np.allclose(
r1[i], r2[i], atol=1e-4, rtol=0) for i in range(num_iter)])
if check_res.all():
print('Pass')
print(np.mean(r1))
print(np.mean(r2))
else:
for i in np.where(np.logical_not(check_res))[0]:
diff = np.fabs(r1[i] - r2[i])
idx = np.unravel_index(diff.argmax(), diff.shape)
print('Failed iter:', i, "max diff:",
diff[idx], idx, r1[i][idx], r2[i][idx])
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
flags.DEFINE_string("profiling_output_file", None,
"The output file for profiling results.")
flags.mark_flag_as_required("profiling_output_file")
flags.DEFINE_string("floatx", "float32", "float32 or float16")
flags.mark_flag_as_required("floatx")
flags.DEFINE_bool("xla", False, "whether to turn on XLA")
flags.mark_flag_as_required("xla")
flags.DEFINE_bool("tf_profile", False,
"whether to use tensorflow profiling")
tf.app.run()

View file

@ -0,0 +1,223 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# export BERT_BASE_DIR=/home/dongluw/Desktop/data/uncased_L-12_H-768_A-12
# python profile_transformer_inference.py --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt --tf_profile=false --output_dir=mrpc_output --profiling_output_file=time_elapsed --xla=false --floatx=float32
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import device_lib
import time
import contextlib
from tensorflow.python.client import timeline
import os
import tensorflow as tf
import fast_infer_util as fiu
import numpy as np
import profile_util
import sys
import my_modeling
bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert')
sys.path.insert(0, bert_submodule)
import run_classifier
import optimization
flags = tf.flags
FLAGS = flags.FLAGS
# stacked transformer encoders
class TransformerModel(object):
def __init__(self,
config,
is_training,
input_tensor,
attention_mask,
transformer_model_fn,
scope=None):
config = my_modeling.copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
input_shape = my_modeling.get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope("encoder"):
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model_fn(
input_tensor=input_tensor,
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_act_fn=my_modeling.get_activation(
config.hidden_act),
hidden_dropout_prob=config.hidden_dropout_prob,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
initializer_range=config.initializer_range,
do_return_all_layers=True)
self.sequence_output = self.all_encoder_layers[-1]
with tf.variable_scope("pooler"):
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
kernel_initializer=my_modeling.create_initializer(config.initializer_range))
def get_pooled_output(self):
return self.pooled_output
def get_sequence_output(self):
return self.sequence_output
def model_fn_builder(bert_config, transformer_model_fn):
def model_fn(input_tensor, attention_mask): # pylint: disable=unused-argument
model = TransformerModel(
config=bert_config,
is_training=False,
input_tensor=input_tensor,
attention_mask=attention_mask,
transformer_model_fn=transformer_model_fn)
seq_output = model.get_sequence_output()
return seq_output
return model_fn
def profile_model(config, jit_xla, num_iter):
# initialize data
input_data = np.random.randn(
FLAGS.predict_batch_size, FLAGS.max_seq_length, config.hidden_size)
attention_mask = np.random.randint(2, size=(
FLAGS.predict_batch_size, FLAGS.max_seq_length))
attention_mask = np.repeat(
attention_mask[:, np.newaxis, :], FLAGS.max_seq_length, axis=1)
model_fn_tf = model_fn_builder(config, my_modeling.transformer_model)
model_fn_ft = model_fn_builder(config, fiu.fast_transformer_model_trans)
def graph_fn_builder(model_fn):
def graph_fn():
input_tensor = tf.constant(input_data, dtype=FLAGS.floatx)
mask_tensor = tf.constant(attention_mask, dtype=FLAGS.floatx)
output_var = model_fn(input_tensor, mask_tensor)
# for saving memcopy time
return tf.reduce_mean(output_var)
return graph_fn
if FLAGS.tf_profile:
tf.logging.info("***** Running tensorflow transformer*****")
p1 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_origin'))
t1, r1 = profile_util.run_profile(graph_fn_builder(
model_fn_tf), jit_xla, num_iter, p1, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
tf.logging.info("***** Running fast transformer*****")
p2 = profile_util.Profiler(os.path.join(
FLAGS.output_dir, 'prof/bert_fastinfer'))
t2, r2 = profile_util.run_profile(graph_fn_builder(
model_fn_ft), jit_xla, num_iter, p2, init_checkpoint=FLAGS.init_checkpoint)
else:
tf.logging.info("***** Running tensorflow transformer*****")
t1, r1 = profile_util.run_profile(graph_fn_builder(
model_fn_tf), jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
tf.reset_default_graph()
tf.logging.info("***** Running fast transformer*****")
t2, r2 = profile_util.run_profile(graph_fn_builder(
model_fn_ft), jit_xla, num_iter, check_result=False, init_checkpoint=FLAGS.init_checkpoint)
# check errors
print('average time (seconds) elasped original tensorflow:', t1)
print('average time (seconds) elasped fast transformer:', t2)
if len(r1) + len(r2) > 0:
check_res = np.asarray([np.allclose(
r1[i], r2[i], atol=1e-4, rtol=0) for i in range(num_iter)])
if check_res.all():
print('Pass')
print(np.mean(r1))
print(np.mean(r2))
else:
for i in np.where(np.logical_not(check_res))[0]:
diff = np.fabs(r1[i] - r2[i])
idx = np.unravel_index(diff.argmax(), diff.shape)
print('Failed iter:', i, "max diff:",
diff[idx], idx, r1[i][idx], r2[i][idx])
return t1, t2
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
batch_size = [8]
seq_length = [128]
num_hidden_layers = [12]
attention_heads_num_size = [(12, 64)]
num_iter = 20
interval = 0
# collect results of both original bert and fast transformer
jit_xla = tf.OptimizerOptions.ON_1 if FLAGS.xla else 0
config = my_modeling.BertConfig(vocab_size=0)
tf.gfile.MakeDirs(FLAGS.output_dir)
local_device_protos = device_lib.list_local_devices()
with open(os.path.join(FLAGS.output_dir, FLAGS.profiling_output_file), 'w') as f:
for x in local_device_protos:
if x.device_type == 'GPU':
f.write(x.physical_device_desc + '\n')
f.write(str(FLAGS.floatx) + '\t' + 'XLA: ' + str(FLAGS.xla) + '\n')
f.write('batch_size\tseq_length\thidden_layers\tattention_heads\tattention_head_size\tTensorflow\tFasterTransformer\n')
for bs in batch_size:
FLAGS.predict_batch_size = bs
for sl in seq_length:
FLAGS.max_seq_length = sl
for hidden_layers in num_hidden_layers:
config.num_hidden_layers = hidden_layers
for head_num, head_size in attention_heads_num_size:
config.num_attention_heads = head_num
config.hidden_size = head_num * head_size
time.sleep(interval)
t1, t2 = profile_model(config, jit_xla, num_iter)
tmp = [FLAGS.predict_batch_size, FLAGS.max_seq_length, hidden_layers, head_num, head_size,
'{:.6}'.format(t1), '{:.6}'.format(t2)]
f.write('\t'.join([str(x) for x in tmp]) + '\n')
if __name__ == "__main__":
flags.mark_flag_as_required("output_dir")
flags.DEFINE_string("profiling_output_file", None,
"The output file for profiling results.")
flags.mark_flag_as_required("profiling_output_file")
flags.DEFINE_string("floatx", "float32", "float32 or float16")
flags.mark_flag_as_required("floatx")
flags.DEFINE_bool("xla", False, "whether to turn on XLA")
flags.mark_flag_as_required("xla")
flags.DEFINE_bool("tf_profile", False,
"whether to use tensorflow profiling")
tf.app.run()

View file

@ -0,0 +1,88 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow.python.client import device_lib
import time
import contextlib
from tensorflow.python.client import timeline
import os
import tensorflow as tf
class Profiler():
def __init__(self, profile_name_pref):
self.profile_name_pref = profile_name_pref
self.run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
self.run_metadata = tf.RunMetadata()
self.ctr = 0
self.time_avg = 0
@contextlib.contextmanager
def prof_run(self):
start = time.time()
yield
end = time.time()
self.time_avg = (self.time_avg * self.ctr + end - start)/(self.ctr + 1)
fetched_timeline = timeline.Timeline(self.run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
file_name = self.profile_name_pref + '_' + str(self.ctr) + '.json'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(chrome_trace)
self.ctr += 1
def run_profile(graph_fn, jit_xla, num_iter, profiler=None, init_checkpoint=None, check_result=True, dryrun_iter=1):
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = jit_xla
fetches = graph_fn()
with tf.Session(config=config) as sess:
# init
if init_checkpoint is None:
sess.run(tf.global_variables_initializer())
else:
saver = tf.train.Saver()
saver.restore(sess, init_checkpoint)
# dry run
for _ in range(dryrun_iter):
sess.run(fetches)
res = []
if profiler is None:
start_time = time.time()
if check_result:
for _ in range(num_iter):
res.append(sess.run(fetches))
else:
for _ in range(num_iter):
sess.run(fetches)
end_time = time.time()
time_avg = (end_time - start_time)/num_iter
else:
if check_result:
for _ in range(num_iter):
with profiler.prof_run():
res.append(sess.run(fetches, options=profiler.run_options, run_metadata=profiler.run_metadata))
else:
for _ in range(num_iter):
with profiler.prof_run():
sess.run(fetches, options=profiler.run_options, run_metadata=profiler.run_metadata)
time_avg = profiler.time_avg
return time_avg, res

View file

@ -0,0 +1,76 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# usage example
# export GLUE_DIR=/home/dongluw/Desktop/data/glue_data
# export BERT_BASE_DIR=/home/dongluw/Desktop/data/uncased_L-12_H-768_A-12
# python run_classifier_wrap.py --floatx=float16 --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=mrpc_output/fp16_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output
# FP32 Tensorflow Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.877451
# INFO:tensorflow: eval_loss = 0.44744828
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.44744828
# FP32 Faster Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.877451
# INFO:tensorflow: eval_loss = 0.4474482
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.4474482
# FP16 Tensorflow Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.875
# INFO:tensorflow: eval_loss = 0.44760832
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.44760215
# FP16 Faster Transformer MRPC result
# INFO:tensorflow: eval_accuracy = 0.875
# INFO:tensorflow: eval_loss = 0.44731623
# INFO:tensorflow: global_step = 0
# INFO:tensorflow: loss = 0.44728807
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
bert_submodule = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert')
sys.path.insert(0, bert_submodule)
import tensorflow as tf
import run_classifier as rc
import fast_infer_util as fiu
import my_modeling
flags = tf.flags
FLAGS = flags.FLAGS
# replace transformer implementation
my_modeling.transformer_model = fiu.fast_transformer_model_trans
# replace the model to support fp16 data type
rc.create_model = fiu.create_model
# replace the input function to drop remainder
rc.file_based_input_fn_builder = fiu.file_based_input_fn_builder_drop
main = rc.main
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
flags.DEFINE_string("floatx", None, "float32 or float16")
flags.mark_flag_as_required("floatx")
tf.app.run()

View file

@ -0,0 +1,148 @@
Tensorflow BERT Samples
---
**Using bert_transformer Tensorflow op in a transformer encoder**
The trunk network of BERT model consists of a multi-layer transformer encoder,
which is implemented as the `transformer_model()` function in the file `modeling.py` in their official [Github repository](https://github.com/google-research/bert).
Samples provided in file `fast_infer_util.py` show how to re-implement this function with our ops in order to get an inference time speedup.
The function `fast_transformer_model_trans()` implements the transformer encoder using the `bert_transformer` op.
In order to do that, we only need to first import the op at the beginning of the file, then call `bert_transformer` op at the end of each encoder layer. This turns out can be done by adding several lines of code to the original `transformer_model()` function as the following.
```python
# import op
transformer_op_module = tf.load_op_library(os.path.join('../../build/lib/libtf_fastertransformer.so'))
...
def fast_transformer_model_trans(...)
...
# original code
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
layer_output = dropout(layer_output, hidden_dropout_prob)
layer_output = layer_norm(layer_output + attention_output)
# calling bert_transformer
trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=tf.get_variable_scope().name)
layer_output = transformer_op_module.bert_transformer(
layer_input,
layer_input,
trainable_vars[0], trainable_vars[2], trainable_vars[4], trainable_vars[1], trainable_vars[3], trainable_vars[5],
attention_mask,
trainable_vars[6], trainable_vars[7], trainable_vars[8], trainable_vars[9], trainable_vars[10], trainable_vars[11],
trainable_vars[12], trainable_vars[13], trainable_vars[14], trainable_vars[15],
batch_size=batch_size, from_seq_len=seq_length, to_seq_len=seq_length, head_num=num_attention_heads, size_per_head=attention_head_size)
# original code
prev_output = layer_output
all_layer_outputs.append(layer_output)
...
```
**Running GLEU tasks with fast transformer inference**
The above shows how to implement a transformer encoder using our ops, to integrate it into the BERT pipeline
we can simply replace the `transformer_model` function in `modeling.py` with `fast_transformer_model_trans`.
Our implementation supports FP16 data type to further exploit the potential of inference acceleration.
FP16 inference was not supported by the original BERT code, here we made necessary modifications to build a FP16 compatible model,
which was implemented in `my_modeling.py` and the `create_model` function in `fast_infer_util.py`.
FP32 Tensorflow checkpoint files cannot be used directly for FP16 inference, we can convert its data type to FP16 in advance.
The `ckpt_type_convert.py` script is provided for checkpoint data type conversion.
It is also important to note that our implementation requires a fixed batch size, this can be done by setting `drop_remainder` option to `True` for Tensorflow `Dataset` instances. We have re-implemented this as well in the `file_based_input_fn_builder_drop` function.
On top of the above modifications, it's easy to run any of the GLEU tasks supported by the open source BERT sample with our ops for better inference performance. We only need to replace several functions in original `run_classifier.py` script with the implementations we provide.
```python
import run_classifier as rc
import fast_infer_util as fiu
import my_modeling
...
# replace transformer implementation
my_modeling.transformer_model = fiu.fast_transformer_model_trans
# replace the model to support fp16 data type
rc.create_model = fiu.create_model
# replace the input function to drop remainder
rc.file_based_input_fn_builder = fiu.file_based_input_fn_builder_drop
...
```
The sample `run_classifier_wrap.py` is a wrapper of the original `run_classifier.py` script for BERT, it supports the same options as described in [BERT readme](https://github.com/google-research/bert) with additional `floatx` options to specify floating point type.
For example, to compare the performance of original BERT and our implementation on MRPC task we can run the following command.
```bash
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
python run_classifier.py --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=ckpt_dir/fp32_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output
python run_classifier_wrap.py --task_name=MRPC --do_eval=true --data_dir=$GLUE_DIR/MRPC --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=ckpt_dir/fp16_model.ckpt --max_seq_length=128 --eval_batch_size=8 --output_dir=mrpc_output --floatx=float16
```
The evaluation result should be like
```
# original Tensorflow op
...
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: eval_accuracy = 0.877451
INFO:tensorflow: eval_loss = 0.44744828
INFO:tensorflow: global_step = 0
INFO:tensorflow: loss = 0.44744828
# faster_transformer op with fp16 data type
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: eval_accuracy = 0.875
INFO:tensorflow: eval_loss = 0.44731623
INFO:tensorflow: global_step = 0
INFO:tensorflow: loss = 0.44728807
...
```
We see the evaluation accuracy and loss drop slightly with FP16 inference for the MRPC sentence pair classification task.
The following section will show such minor sacrifice in accuracy will bring considerable performance gain.
**Tensorflow profiling**
The sample script `profile_transformer_inference.py` shows how to run and profile a BERT inference model from scratch. Results show we received a 6.36x speedup compared to FP32 Tensorflow (1.48x speedup compared to FP16 Tensorflow XLA) for an end-to-end classification model in our experiment settings.
GPU: Tesla T4
CUDA: 10.0.0
Model: BERT-Base: 12-layer, 768-hidden, 12-heads , 110M parameters
Max sequence length: 128
Batch size: 32
Average time elapsed:
| settings | seconds |
| ------------- | ------------- |
| FP32 Tensorflow | 0.2495 |
| FP32 Tensorflow XLA | 0.1998 |
| FP16 Tensorflow | 0.0978 |
| FP16 Tensorflow XLA | 0.0582 |
| FP16 FasterTransformer | 0.0392 |
**Content summary**
| file name | summary |
| ------------- | ------------- |
| `ckpt_type_convert.py` | script for checkpoint data type conversion |
| `fast_infer_util.py` | example functions to use faster transformer ops in Tensorflow |
| `my_modeling.py` | basically the same as `modeling.py` in the original BERT repository, modifications are made to support FP16 data types |
| `run_classifier_wrap.py` | a wrapper script of `run_classifier.py` in the original BERT repository, shows how to run classification tasks using faster transformer ops |
| `profile_bert_inference.py` | for profiling BERT model pipelines |
| `profile_transformer_inference.py` | for profiling transformer encoder layers |
| `profile_util.py` | helper functions for profiling |

View file

@ -0,0 +1,21 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8)
set(trt_files
transformer_trt.cc
)
add_executable(transformer_trt ${trt_files})
target_link_libraries(transformer_trt PRIVATE -lcublas -lcudart -lnvinfer fastertransformer ${CMAKE_THREAD_LIBS_INIT})

View file

@ -0,0 +1,164 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fastertransformer/trt_plugin/trt_model.h"
#include <cstdio>
#include <cstdlib>
#include <cuda_profiler_api.h>
#include <iostream>
#include <sys/time.h>
using namespace fastertransformer;
double diffTime(timeval start, timeval end)
{
return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001;
}
template <typename T>
void host_malloc(T** ptr, int size)
{
(*ptr) = (T*)malloc(sizeof(T) * size);
}
template <typename T>
void run_bert_transformer(int batch_size, int seq_len, int layers, int head_num, int size_per_head){
int hidden_dim = head_num * size_per_head;
std::vector<std::vector<T *> > params;
T *h_from_tensor = NULL, *h_transformer_out = NULL;
T *h_attr_mask = NULL;
host_malloc(&h_from_tensor, batch_size * seq_len * hidden_dim);
host_malloc(&h_transformer_out, batch_size * seq_len * hidden_dim);
host_malloc(&h_attr_mask, batch_size * seq_len * seq_len);
for(int i = 0; i < batch_size * seq_len * hidden_dim; ++i)
h_from_tensor[i] = 0.001f;
for(int i = 0; i < batch_size * seq_len * seq_len; ++i)
h_attr_mask[i] = 1.0f;
for(int i = 0; i < layers; ++i)
{
T *h_attr_kernel_Q = NULL, *h_attr_kernel_K = NULL, *h_attr_kernel_V = NULL;
T *h_attr_bias_Q = NULL, *h_attr_bias_K = NULL, *h_attr_bias_V = NULL;
T *h_attr_output_kernel = NULL, *h_attr_output_bias = NULL;
T *h_attr_output_layernorm_beta = NULL;
T *h_attr_output_layernorm_gamma = NULL;
T *h_inter_kernel = NULL, *h_inter_bias = NULL;
T *h_output_kernel = NULL, *h_output_bias = NULL, *h_output_layernorm_beta = NULL, *h_output_layernorm_gamma = NULL;
host_malloc(&h_attr_kernel_Q, hidden_dim * hidden_dim);
host_malloc(&h_attr_kernel_K, hidden_dim * hidden_dim);
host_malloc(&h_attr_kernel_V, hidden_dim * hidden_dim);
host_malloc(&h_attr_bias_Q, hidden_dim);
host_malloc(&h_attr_bias_K, hidden_dim);
host_malloc(&h_attr_bias_V, hidden_dim);
host_malloc(&h_attr_output_kernel, hidden_dim * hidden_dim);
host_malloc(&h_attr_output_bias, hidden_dim);
host_malloc(&h_attr_output_layernorm_beta, hidden_dim);
host_malloc(&h_attr_output_layernorm_gamma, hidden_dim);
host_malloc(&h_inter_kernel, hidden_dim * hidden_dim * 4);
host_malloc(&h_inter_bias, hidden_dim * 4);
host_malloc(&h_output_kernel, hidden_dim * hidden_dim * 4);
host_malloc(&h_output_bias, hidden_dim);
host_malloc(&h_output_layernorm_beta, hidden_dim);
host_malloc(&h_output_layernorm_gamma, hidden_dim);
for(int i = 0; i < hidden_dim * hidden_dim; ++i)
{
h_attr_kernel_Q[i] = 0.001f;
h_attr_kernel_K[i] = 0.001f;
h_attr_kernel_V[i] = 0.001f;
h_attr_output_kernel[i] = 0.0001f * i;
if(i < hidden_dim)
{
h_attr_bias_Q[i] = 0.001f;
h_attr_bias_K[i] = 0.001f;
h_attr_bias_V[i] = 0.001f;
h_attr_output_bias[i] = 0.001f;
h_attr_output_layernorm_beta[i] = 0.0001f * i;
h_attr_output_layernorm_gamma[i] = 0.001f * i;
h_output_bias[i] = 0.001f;
h_output_layernorm_beta[i] = 0.001f;
h_output_layernorm_gamma[i] = 0.001f;
}
if(i < hidden_dim * 4)
h_inter_bias[i] = 0.001f;
}
for(int i = 0; i < 4 * hidden_dim * hidden_dim; ++i)
{
h_inter_kernel[i] = 0.001f;
h_output_kernel[i] = 0.001f;
}
std::vector<T* > layer_param;
layer_param.push_back(h_attr_kernel_Q);
layer_param.push_back(h_attr_kernel_K);
layer_param.push_back(h_attr_kernel_V);
layer_param.push_back(h_attr_bias_Q);
layer_param.push_back(h_attr_bias_K);
layer_param.push_back(h_attr_bias_V);
layer_param.push_back(h_attr_output_kernel);
layer_param.push_back(h_attr_output_bias);
layer_param.push_back(h_attr_output_layernorm_beta);
layer_param.push_back(h_attr_output_layernorm_gamma);
layer_param.push_back(h_inter_kernel);
layer_param.push_back(h_inter_bias);
layer_param.push_back(h_output_kernel);
layer_param.push_back(h_output_bias);
layer_param.push_back(h_output_layernorm_beta);
layer_param.push_back(h_output_layernorm_gamma);
params.push_back(layer_param);
}
cudaStream_t stream;
cudaStreamCreate(&stream);
TRT_Transformer<T>* trt_transformer = new TRT_Transformer<T>(batch_size, seq_len, head_num, hidden_dim, layers);
trt_transformer->build_engine(params);
trt_transformer->do_inference(batch_size, h_from_tensor, h_attr_mask, h_transformer_out, stream);
delete trt_transformer;
printf("finished!\n");
}
int main(int argc, char* argv[])
{
if(argc != 7)
{
printf("./transformer_trt batch_size num_layers seq_len head_num size_per_head fp32/fp16\n");
printf("e.g., ./transformer_trt 1 12 32 12 64 fp32\n");
printf("e.g., ./transformer_trt 1 12 32 12 64 fp16\n");
return 0;
}
int batch_size = atoi(argv[1]);
int num_layers = atoi(argv[2]);
int seq_len = atoi(argv[3]);
int head_num = atoi(argv[4]);
int size_per_head = atoi(argv[5]);
if(strcmp(argv[6], "fp16") == 0)
run_bert_transformer<__half>(batch_size, seq_len, num_layers, head_num, size_per_head);
else if(strcmp(argv[6], "fp32") == 0)
run_bert_transformer<float>(batch_size, seq_len, num_layers, head_num, size_per_head);
else
{
printf("the last argument is invalid, it should be fp16 or fp32\n");
return 0;
}
}

View file

@ -0,0 +1,359 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np
import os
import math
import six
from datetime import datetime
import sys
transformer_op_module = tf.load_op_library(os.path.join('./lib/libtf_transformer.so'))
argumentList = sys.argv
batch_size = int(sys.argv[1])
num_layers = int(sys.argv[2])
seq_len = int(sys.argv[3])
print("Argumentlist: batch_size " + str(batch_size) + " num_layers " + str(num_layers) + " seq_len " + str(seq_len))
head_num = 12
size_per_head = 64
hidden_dim = head_num * size_per_head
initializer_range = 0.02
from_data = np.random.randn(batch_size, seq_len, hidden_dim)
from_tensor = tf.convert_to_tensor(from_data, dtype=float)
mask = np.random.randint(2, size=(batch_size, seq_len, seq_len))
attention_mask = tf.convert_to_tensor(mask, dtype=float)
def gelu(x):
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def layer_norm(input_tensor, name=None):
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
def create_initializer(initializer_range=0.02):
return tf.truncated_normal_initializer(stddev=initializer_range)
def attention_layer(from_tensor,
to_tensor,
attention_mask=None,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
do_return_2d_tensor=False,
batch_size=None,
from_seq_length=None,
to_seq_length=None):
def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
seq_length, width):
output_tensor = tf.reshape(
input_tensor, [batch_size, seq_length, num_attention_heads, width])
output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
return output_tensor
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
if len(from_shape) != len(to_shape):
raise ValueError(
"The rank of `from_tensor` must match the rank of `to_tensor`.")
if len(from_shape) == 3:
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_seq_length = to_shape[1]
elif len(from_shape) == 2:
if (batch_size is None or from_seq_length is None or to_seq_length is None):
raise ValueError(
"When passing in rank 2 tensors to attention_layer, the values "
"for `batch_size`, `from_seq_length`, and `to_seq_length` "
"must all be specified.")
from_tensor_2d = reshape_to_matrix(from_tensor)
to_tensor_2d = reshape_to_matrix(to_tensor)
# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
num_attention_heads, from_seq_length,
size_per_head)
# `key_layer` = [B, N, T, H]
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
to_seq_length, size_per_head)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(size_per_head)))
if attention_mask is not None:
# `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1])
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
attention_scores += adder
attention_probs = tf.nn.softmax(attention_scores)
value_layer = tf.reshape(
value_layer,
[batch_size, to_seq_length, num_attention_heads, size_per_head])
value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
if do_return_2d_tensor:
context_layer = tf.reshape(
context_layer,
[batch_size * from_seq_length, num_attention_heads * size_per_head])
else:
context_layer = tf.reshape(
context_layer,
[batch_size, from_seq_length, num_attention_heads * size_per_head])
return context_layer
def transformer_model(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
initializer_range=0.02,
do_return_all_layers=False):
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
prev_output = reshape_to_matrix(input_tensor)
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE):
layer_input = prev_output
with tf.variable_scope("attention"):
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_output = attention_head
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
attention_output = layer_norm(attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
layer_output = layer_norm(layer_output + attention_output)
prev_output = layer_output
return prev_output
def get_shape_list(tensor, expected_rank=None, name=None):
if name is None:
name = tensor.name
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def reshape_to_matrix(input_tensor):
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
ndims = input_tensor.shape.ndims
if ndims < 2:
raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
(input_tensor.shape))
if ndims == 2:
return input_tensor
width = input_tensor.shape[-1]
output_tensor = tf.reshape(input_tensor, [-1, width])
return output_tensor
def reshape_from_matrix(output_tensor, orig_shape_list):
if len(orig_shape_list) == 2:
return output_tensor
output_shape = get_shape_list(output_tensor)
orig_dims = orig_shape_list[0:-1]
width = output_shape[-1]
return tf.reshape(output_tensor, orig_dims + [width])
def assert_rank(tensor, expected_rank, name=None):
if name is None:
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
scope_name = tf.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
def transformer_single(input_tensor, params, layer_idx):
val_off = layer_idx * 16
output = transformer_op_module.bert_transformer(
input_tensor,
input_tensor,
params[val_off + 0], params[val_off + 2], params[val_off + 4], params[val_off + 1], params[val_off + 3], params[val_off + 5], attention_mask,
params[val_off + 6], params[val_off + 7], params[val_off + 8], params[val_off + 9], params[val_off + 10],
params[val_off + 11], params[val_off + 12], params[val_off + 13], params[val_off + 14], params[val_off + 15],
batch_size = batch_size, from_seq_len = seq_len, to_seq_len = seq_len, head_num = head_num, size_per_head = size_per_head)
return output
def transformer_own(input_tensor, params):
in_tensor = input_tensor
for layer_idx in range(num_layers):
out_tensor = transformer_single(in_tensor, params, layer_idx)
in_tensor = out_tensor
return in_tensor
output = transformer_model(input_tensor=from_tensor, attention_mask = attention_mask, num_hidden_layers = num_layers, do_return_all_layers=True)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(output)
Model_variables = tf.GraphKeys.GLOBAL_VARIABLES
idx = 0
all_vars = tf.get_collection(Model_variables)
for var in all_vars:
print (str(idx) + " " + str(var.name) + " " + str(var.shape))
idx = idx + 1
params = all_vars
output_own = transformer_own(from_tensor, params)
for ite in range(20):
print("ite " + str(ite))
try:
sess.run(output_own)
except tf.errors.InvalidArgumentError as e:
print(e)
except tf.errors.InternalError as e:
print(e)
except:
print("Runtime error")

View file

@ -0,0 +1,407 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np
import os
import math
import six
from datetime import datetime
import sys
transformer_op_module = tf.load_op_library(os.path.join('./lib/libtf_fastertransformer.so'))
if len(sys.argv) != 6:
print "python transformer_fp16.py batch_size num_layers seq_len head_num size_per_head"
print "e.g., python transformer_fp16.py 1 12 32 12 64"
sys.exit(0)
argumentList = sys.argv
batch_size = int(sys.argv[1])
num_layers = int(sys.argv[2])
seq_len = int(sys.argv[3])
head_num = int(sys.argv[4])
size_per_head = int(sys.argv[5])
#batch_size = 192
#num_layers = 6
#seq_len = 32
#head_num = 8
#size_per_head = 96
print("Argumentlist: batch_size " + str(batch_size) + " num_layers " + str(num_layers) + " seq_len " + str(seq_len))
hidden_dim = head_num * size_per_head
initializer_range = 0.02
from_data = np.random.randn(batch_size, seq_len, hidden_dim)
from_tensor = tf.convert_to_tensor(from_data, dtype=tf.float16)
mask = np.random.randint(2, size=(batch_size, seq_len, seq_len))
attention_mask = tf.convert_to_tensor(mask, dtype=tf.float16)
def gelu(x):
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def layer_norm(input_tensor, name=None):
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
def create_initializer(initializer_range=0.02):
return tf.truncated_normal_initializer(stddev=initializer_range)
def attention_layer(from_tensor,
to_tensor,
attention_mask=None,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
do_return_2d_tensor=False,
batch_size=None,
from_seq_length=None,
to_seq_length=None):
def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
seq_length, width):
output_tensor = tf.reshape(
input_tensor, [batch_size, seq_length, num_attention_heads, width])
output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
return output_tensor
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
if len(from_shape) != len(to_shape):
raise ValueError(
"The rank of `from_tensor` must match the rank of `to_tensor`.")
if len(from_shape) == 3:
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_seq_length = to_shape[1]
elif len(from_shape) == 2:
if (batch_size is None or from_seq_length is None or to_seq_length is None):
raise ValueError(
"When passing in rank 2 tensors to attention_layer, the values "
"for `batch_size`, `from_seq_length`, and `to_seq_length` "
"must all be specified.")
from_tensor_2d = reshape_to_matrix(from_tensor)
to_tensor_2d = reshape_to_matrix(to_tensor)
# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
num_attention_heads, from_seq_length,
size_per_head)
# `key_layer` = [B, N, T, H]
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
to_seq_length, size_per_head)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(size_per_head)))
if attention_mask is not None:
# `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1])
adder = (1.0 - tf.cast(attention_mask, tf.float16)) * -10000.0
attention_scores += adder
attention_probs = tf.nn.softmax(attention_scores)
value_layer = tf.reshape(
value_layer,
[batch_size, to_seq_length, num_attention_heads, size_per_head])
value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
if do_return_2d_tensor:
context_layer = tf.reshape(
context_layer,
[batch_size * from_seq_length, num_attention_heads * size_per_head])
else:
context_layer = tf.reshape(
context_layer,
[batch_size, from_seq_length, num_attention_heads * size_per_head])
return context_layer
def transformer_model(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
initializer_range=0.02,
do_return_all_layers=False):
intermediate_size=hidden_size * 4
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
prev_output = reshape_to_matrix(input_tensor)
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE):
layer_input = prev_output
with tf.variable_scope("attention"):
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_output = attention_head
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
attention_output = layer_norm(attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
layer_output = layer_norm(layer_output + attention_output)
prev_output = layer_output
return prev_output
def get_shape_list(tensor, expected_rank=None, name=None):
if name is None:
name = tensor.name
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def reshape_to_matrix(input_tensor):
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
ndims = input_tensor.shape.ndims
if ndims < 2:
raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
(input_tensor.shape))
if ndims == 2:
return input_tensor
width = input_tensor.shape[-1]
output_tensor = tf.reshape(input_tensor, [-1, width])
return output_tensor
def reshape_from_matrix(output_tensor, orig_shape_list):
if len(orig_shape_list) == 2:
return output_tensor
output_shape = get_shape_list(output_tensor)
orig_dims = orig_shape_list[0:-1]
width = output_shape[-1]
return tf.reshape(output_tensor, orig_dims + [width])
def assert_rank(tensor, expected_rank, name=None):
if name is None:
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
scope_name = tf.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
def transformer_single(input_tensor, params, layer_idx):
val_off = layer_idx * 16
output = transformer_op_module.bert_transformer(
input_tensor,
input_tensor,
params[val_off + 0], params[val_off + 2], params[val_off + 4], params[val_off + 1], params[val_off + 3], params[val_off + 5], attention_mask,
params[val_off + 6], params[val_off + 7], params[val_off + 8], params[val_off + 9], params[val_off + 10],
params[val_off + 11], params[val_off + 12], params[val_off + 13], params[val_off + 14], params[val_off + 15],
batch_size = batch_size, from_seq_len = seq_len, to_seq_len = seq_len, head_num = head_num, size_per_head = size_per_head)
return output
def transformer_own(input_tensor, params):
in_tensor = input_tensor
for layer_idx in range(num_layers):
out_tensor = transformer_single(in_tensor, params, layer_idx)
in_tensor = out_tensor
return in_tensor
output = transformer_model(input_tensor=from_tensor,
hidden_size = hidden_dim, num_attention_heads = head_num, attention_mask = attention_mask, num_hidden_layers = num_layers, do_return_all_layers=True)
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(output)
Model_variables = tf.GraphKeys.GLOBAL_VARIABLES
idx = 0
all_vars = tf.get_collection(Model_variables)
for var in all_vars:
print (str(idx) + " " + str(var.name) + " " + str(var.shape)) + " " + str(var.dtype)
idx = idx + 1
params = all_vars
output_own = transformer_own(from_tensor, params)
print("#################################")
np_val1 = sess.run(output)
np_val2 = sess.run(output_own)
print("cross_check " + str(np.allclose(np_val1, np_val2, atol = 1e-5)))
print("max diff " + str(np.fabs(np_val1 - np_val2).max()))
print("min diff " + str(np.fabs(np_val1 - np_val2).min()))
print np_val1
print " "
print np_val2
ite = 500
time_sum = 0
a = datetime.now()
for i in range(ite):
sess.run(output)
b = datetime.now()
time_sum = (b - a).total_seconds()
print("original costs " + str(time_sum * 1000 / ite) + " ms")
ite = 500
time_sum = 0
a = datetime.now()
for i in range(ite):
sess.run(output_own)
b = datetime.now()
time_sum = (b - a).total_seconds()
print("optimized costs " + str(time_sum * 1000 / ite) + " ms")
for _ in range(50):
sess.run(output_own)
ret = sess.run(output_own, options=run_options, run_metadata=run_metadata)
from tensorflow.python.client import timeline
fetched_timeline = timeline.Timeline(run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
with open("transformer_own_fp16.json", 'w') as f:
f.write(chrome_trace)
ret = sess.run(output, options=run_options, run_metadata=run_metadata)
from tensorflow.python.client import timeline
fetched_timeline = timeline.Timeline(run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
with open("transformer_fp16.json", 'w') as f:
f.write(chrome_trace)

View file

@ -0,0 +1,412 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np
import os
import math
import six
from datetime import datetime
import sys
import time
transformer_op_module = tf.load_op_library(os.path.join('./lib/libtf_fastertransformer.so'))
if len(sys.argv) != 6:
print "python transformer_fp32.py batch_size num_layers seq_len head_num size_per_head"
print "e.g., python transformer_fp32.py 1 12 32 12 64"
sys.exit(0)
argumentList = sys.argv
batch_size = int(sys.argv[1])
num_layers = int(sys.argv[2])
seq_len = int(sys.argv[3])
head_num = int(sys.argv[4])
size_per_head = int(sys.argv[5])
print("Argumentlist: batch_size " + str(batch_size) + " num_layers " + str(num_layers) + " seq_len " + str(seq_len))
hidden_dim = head_num * size_per_head
initializer_range = 0.02
from_data = np.random.randn(batch_size, seq_len, hidden_dim)
from_tensor = tf.convert_to_tensor(from_data, dtype=float)
mask = np.random.randint(2, size=(batch_size, seq_len, seq_len))
attention_mask = tf.convert_to_tensor(mask, dtype=float)
def gelu(x):
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
def layer_norm(input_tensor, name=None):
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
def create_initializer(initializer_range=0.02):
return tf.truncated_normal_initializer(stddev=initializer_range)
def attention_layer(from_tensor,
to_tensor,
attention_mask=None,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
do_return_2d_tensor=False,
batch_size=None,
from_seq_length=None,
to_seq_length=None):
def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
seq_length, width):
output_tensor = tf.reshape(
input_tensor, [batch_size, seq_length, num_attention_heads, width])
output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
return output_tensor
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
if len(from_shape) != len(to_shape):
raise ValueError(
"The rank of `from_tensor` must match the rank of `to_tensor`.")
if len(from_shape) == 3:
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_seq_length = to_shape[1]
elif len(from_shape) == 2:
if (batch_size is None or from_seq_length is None or to_seq_length is None):
raise ValueError(
"When passing in rank 2 tensors to attention_layer, the values "
"for `batch_size`, `from_seq_length`, and `to_seq_length` "
"must all be specified.")
from_tensor_2d = reshape_to_matrix(from_tensor)
to_tensor_2d = reshape_to_matrix(to_tensor)
# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
num_attention_heads, from_seq_length,
size_per_head)
# `key_layer` = [B, N, T, H]
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
to_seq_length, size_per_head)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(size_per_head)))
if attention_mask is not None:
# `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1])
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
attention_scores += adder
attention_probs = tf.nn.softmax(attention_scores)
value_layer = tf.reshape(
value_layer,
[batch_size, to_seq_length, num_attention_heads, size_per_head])
value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
if do_return_2d_tensor:
context_layer = tf.reshape(
context_layer,
[batch_size * from_seq_length, num_attention_heads * size_per_head])
else:
context_layer = tf.reshape(
context_layer,
[batch_size, from_seq_length, num_attention_heads * size_per_head])
return context_layer
def transformer_model(input_tensor,
attention_mask=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
intermediate_act_fn=gelu,
initializer_range=0.02,
do_return_all_layers=False):
intermediate_size=hidden_size * 4
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
attention_head_size = int(hidden_size / num_attention_heads)
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
input_width = input_shape[2]
prev_output = reshape_to_matrix(input_tensor)
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx, reuse=tf.AUTO_REUSE):
layer_input = prev_output
with tf.variable_scope("attention"):
with tf.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
attention_mask=attention_mask,
num_attention_heads=num_attention_heads,
size_per_head=attention_head_size,
initializer_range=initializer_range,
do_return_2d_tensor=True,
batch_size=batch_size,
from_seq_length=seq_length,
to_seq_length=seq_length)
attention_output = attention_head
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
attention_output,
hidden_size,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
attention_output = layer_norm(attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
use_bias=True,
bias_initializer = create_initializer(initializer_range),
kernel_initializer=create_initializer(initializer_range))
layer_output = layer_norm(layer_output + attention_output)
prev_output = layer_output
return prev_output
def get_shape_list(tensor, expected_rank=None, name=None):
if name is None:
name = tensor.name
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def reshape_to_matrix(input_tensor):
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
ndims = input_tensor.shape.ndims
if ndims < 2:
raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
(input_tensor.shape))
if ndims == 2:
return input_tensor
width = input_tensor.shape[-1]
output_tensor = tf.reshape(input_tensor, [-1, width])
return output_tensor
def reshape_from_matrix(output_tensor, orig_shape_list):
if len(orig_shape_list) == 2:
return output_tensor
output_shape = get_shape_list(output_tensor)
orig_dims = orig_shape_list[0:-1]
width = output_shape[-1]
return tf.reshape(output_tensor, orig_dims + [width])
def assert_rank(tensor, expected_rank, name=None):
if name is None:
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
scope_name = tf.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
def transformer_single(input_tensor, params, layer_idx):
val_off = layer_idx * 16
output = transformer_op_module.bert_transformer(
input_tensor,
input_tensor,
params[val_off + 0], params[val_off + 2], params[val_off + 4], params[val_off + 1], params[val_off + 3], params[val_off + 5], attention_mask,
params[val_off + 6], params[val_off + 7], params[val_off + 8], params[val_off + 9], params[val_off + 10],
params[val_off + 11], params[val_off + 12], params[val_off + 13], params[val_off + 14], params[val_off + 15],
batch_size = batch_size, from_seq_len = seq_len, to_seq_len = seq_len, head_num = head_num, size_per_head = size_per_head)
return output
def transformer_own(input_tensor, params):
in_tensor = input_tensor
for layer_idx in range(num_layers):
out_tensor = transformer_single(in_tensor, params, layer_idx)
in_tensor = out_tensor
return in_tensor
output = transformer_model(input_tensor=from_tensor,
hidden_size = hidden_dim, num_attention_heads = head_num, attention_mask = attention_mask, num_hidden_layers = num_layers, do_return_all_layers=True)
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(output)
Model_variables = tf.GraphKeys.GLOBAL_VARIABLES
idx = 0
all_vars = tf.get_collection(Model_variables)
for var in all_vars:
print (str(idx) + " " + str(var.name) + " " + str(var.shape))
idx = idx + 1
params = all_vars
output_own = transformer_own(from_tensor, params)
print("#################################")
np_val1 = sess.run(output)
np_val2 = sess.run(output_own)
print("cross_check " + str(np.allclose(np_val1, np_val2, atol = 1e-5)))
print("max diff " + str(np.fabs(np_val1 - np_val2).max()))
print("min diff " + str(np.fabs(np_val1 - np_val2).min()))
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
for _ in range(50):
sess.run(output_own)
ret = sess.run(output_own, options=run_options, run_metadata=run_metadata)
from tensorflow.python.client import timeline
fetched_timeline = timeline.Timeline(run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
with open("transformer_own.json", 'w') as f:
f.write(chrome_trace)
for _ in range(50):
sess.run(output)
ret = sess.run(output, options=run_options, run_metadata=run_metadata)
from tensorflow.python.client import timeline
fetched_timeline = timeline.Timeline(run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
with open("transformer.json", 'w') as f:
f.write(chrome_trace)
for _ in range(500):
sess.run(output)
ite = 500
time_sum = 0
for i in range(ite):
a = datetime.now()
sess.run(output)
b = datetime.now()
time_sum += (b - a).total_seconds()
print("original costs " + str(time_sum * 1000 / ite) + " ms")
for _ in range(500):
sess.run(output_own)
time_sum = 0
for i in range(ite):
a = datetime.now()
sess.run(output_own)
b = datetime.now()
time_sum += (b - a).total_seconds()
print("optimized costs " + str(time_sum * 1000 / ite) + " ms")

View file

@ -0,0 +1,15 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
add_subdirectory(gemm_test)

View file

@ -0,0 +1,28 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.8)
set(gemm_fp16_files
gemm_fp16.cu
)
set(gemm_fp32_files
gemm_fp32.cu
)
add_executable(gemm_fp32 ${gemm_fp32_files})
target_link_libraries(gemm_fp32 PUBLIC -lcublas -lcudart ${CMAKE_THREAD_LIBS_INIT})
add_executable(gemm_fp16 ${gemm_fp16_files})
target_link_libraries(gemm_fp16 PUBLIC -lcublas -lcudart ${CMAKE_THREAD_LIBS_INIT})

View file

@ -0,0 +1,73 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cstring>
#include <sstream>
static const char *_cudaGetErrorEnum(cudaError_t error) {
return cudaGetErrorString(error);
}
static const char *_cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "<unknown>";
}
template <typename T>
void check(T result, char const *const func, const char *const file, int const line) {
if (result) {
std::cout << (std::string("[FT][ERROR] CUDA runtime error: ") + \
(_cudaGetErrorEnum(result)) + " " + file + \
":" + std::to_string(line) + " \n");\
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)

View file

@ -0,0 +1,175 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdio>
#include <cstdlib>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ctime>
#include <unistd.h>
#include <sys/time.h>
#include "common.h"
using namespace std;
double diffTime(timeval start, timeval end)
{
return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001;
}
int main(int argc, char* argv[])
{
FILE* fd = fopen("gemm_config.in", "w");
if(fd == NULL)
{
printf("Cannot write to file gemm_config.in\n");
return 0;
}
struct cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
printf("Device %s\n", prop.name);
const int batch_size = atoi(argv[1]);
const int seq_len = atoi(argv[2]);
const int head_num = atoi(argv[3]);
const int size_per_head = atoi(argv[4]);
const int gemm_num = 5;
int M[gemm_num];
int N[gemm_num];
int K[gemm_num];
int batchCount[gemm_num] = {1,1,1,1,1};
char mess[gemm_num][256];
//gemm1
M[0] = batch_size * seq_len;
K[0] = head_num * size_per_head;
N[0] = K[0];
strcpy(mess[0], "from_tensor * weightQ/K/V, attr * output_kernel");
//gemm2
M[1] = M[0];
K[1] = K[0];
N[1] = 4 * N[0];
strcpy(mess[1], "attr_output * inter_kernel");
//gemm3
M[2] = M[0];
K[2] = 4 * K[0];
N[2] = N[0];
strcpy(mess[2], "inter_matmul * output_kernel");
M[3] = seq_len;
N[3] = seq_len;
K[3] = size_per_head;
batchCount[3] = batch_size * head_num;
strcpy(mess[3], "attention batched Gemm1");
M[4] = seq_len;
N[4] = size_per_head;
K[4] = seq_len;
batchCount[4] = batch_size * head_num;
strcpy(mess[4], "attention batched Gemm2");
cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);
typedef __half T;
cudaDataType_t AType = CUDA_R_16F;
cudaDataType_t BType = CUDA_R_16F;
cudaDataType_t CType = CUDA_R_16F;
cudaDataType_t computeType = CUDA_R_16F;
const int ites = 100;
struct timeval start, end;
int startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
int endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
T alpha = (T)1.0f;
T beta = (T)0.0f;
printf("***FP16 Gemm Testing***\n");
for(int i = 0; i < gemm_num; ++i)
{
int m = M[i], n = N[i], k = K[i];
printf("\n-----------------------------\n");
printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]);
T* d_A;
T* d_B;
T* d_C;
check_cuda_error(cudaMalloc((void**)&d_A, sizeof(T) * m * k * batchCount[i]));
check_cuda_error(cudaMalloc((void**)&d_B, sizeof(T) * k * n * batchCount[i]));
check_cuda_error(cudaMalloc((void**)&d_C, sizeof(T) * m * n * batchCount[i]));
float exec_time = 99999.0f;
int fast_algo = 0;
for(int algo = startAlgo; algo <= endAlgo; algo++)
{
cudaDeviceSynchronize();
gettimeofday(&start, NULL);
for(int ite = 0; ite < ites; ++ite)
{
if(i < 3)
{
check_cuda_error(cublasGemmEx(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
d_B, BType, n,
d_A, AType, k,
&beta,
d_C, CType, n,
computeType,
static_cast<cublasGemmAlgo_t>(algo)));
}
else if(i == 3)
{
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_N,
seq_len, seq_len, size_per_head,
&alpha,
d_B, BType, size_per_head, seq_len * size_per_head,
d_A, AType, size_per_head, seq_len * size_per_head,
&beta,
d_C, CType, seq_len, seq_len * seq_len,
batch_size * head_num,
computeType,
static_cast<cublasGemmAlgo_t>(algo)));
}
else
{
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
size_per_head, seq_len, seq_len,
&alpha,
d_B, BType, size_per_head, seq_len * size_per_head,
d_A, AType, seq_len, seq_len * seq_len,
&beta,
d_C, CType, size_per_head, seq_len * size_per_head,
batch_size * head_num,
computeType,
static_cast<cublasGemmAlgo_t>(algo)));
}
}
cudaDeviceSynchronize();
gettimeofday(&end, NULL);
printf("algo_%d costs %.3fms \n", algo, diffTime(start, end) / ites);
if(diffTime(start, end) / ites < exec_time)
{
exec_time = diffTime(start, end) / ites;
fast_algo = algo;
}
}
printf("fast_algo %d costs %.3f ms\n", fast_algo, exec_time);
fprintf(fd, "%d\n", fast_algo);
}
}

View file

@ -0,0 +1,178 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdio>
#include <cstdlib>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ctime>
#include <sys/time.h>
#include "common.h"
using namespace std;
double diffTime(timeval start, timeval end)
{
return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001;
}
int main(int argc, char* argv[])
{
struct cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
printf("Device %s\n", prop.name);
FILE* fd = fopen("gemm_config.in", "w");
if(fd == NULL)
{
printf("Cannot write to file gemm_config.in\n");
return 0;
}
const int batch_size = atoi(argv[1]);
const int seq_len = atoi(argv[2]);
const int head_num = atoi(argv[3]);
const int size_per_head = atoi(argv[4]);
const int gemm_num = 5;
int M[gemm_num];
int N[gemm_num];
int K[gemm_num];
int batchCount[gemm_num] = {1,1,1,1,1};
char mess[gemm_num][256];
//gemm1
M[0] = batch_size * seq_len;
K[0] = head_num * size_per_head;
N[0] = K[0];
strcpy(mess[0], "from_tensor * weightQ/K/V, attr * output_kernel");
//gemm2
M[1] = M[0];
K[1] = K[0];
N[1] = 4 * N[0];
strcpy(mess[1], "attr_output * inter_kernel");
//gemm3
M[2] = M[0];
K[2] = 4 * K[0];
N[2] = N[0];
strcpy(mess[2], "inter_matmul * output_kernel");
M[3] = seq_len;
N[3] = seq_len;
K[3] = size_per_head;
batchCount[3] = batch_size * head_num;
strcpy(mess[3], "attention batched Gemm1");
M[4] = seq_len;
N[4] = size_per_head;
K[4] = seq_len;
batchCount[4] = batch_size * head_num;
strcpy(mess[4], "attention batched Gemm2");
cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);
typedef float T;
cudaDataType_t AType = CUDA_R_32F;
cudaDataType_t BType = CUDA_R_32F;
cudaDataType_t CType = CUDA_R_32F;
cudaDataType_t computeType = CUDA_R_32F;
const int ites = 100;
struct timeval start, end;
int startAlgo = (int)CUBLAS_GEMM_DEFAULT;
int endAlgo = (int)CUBLAS_GEMM_ALGO23;
T alpha = (T)1.0f;
T beta = (T)0.0f;
printf("***FP32 Gemm Testing***\n");
for(int i = 0; i < gemm_num; ++i)
{
int m = M[i], n = N[i], k = K[i];
printf("\n-----------------------------\n");
printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]);
T* d_A;
T* d_B;
T* d_C;
check_cuda_error(cudaMalloc((void**)&d_A, sizeof(T) * m * k * batchCount[i]));
check_cuda_error(cudaMalloc((void**)&d_B, sizeof(T) * k * n * batchCount[i]));
check_cuda_error(cudaMalloc((void**)&d_C, sizeof(T) * m * n * batchCount[i]));
float exec_time = 99999.0f;
int fast_algo = 0;
for(int algo = startAlgo; algo <= endAlgo; algo++)
{
cublasStatus_t status;
cudaDeviceSynchronize();
gettimeofday(&start, NULL);
for(int ite = 0; ite < ites; ++ite)
{
if(i < 3)
{
status = cublasGemmEx(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
d_B, BType, n,
d_A, AType, k,
&beta,
d_C, CType, n,
computeType,
static_cast<cublasGemmAlgo_t>(algo));
}
else if(i == 3)
{
status = cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_N,
seq_len, seq_len, size_per_head,
&alpha,
d_B, BType, size_per_head, seq_len * size_per_head,
d_A, AType, size_per_head, seq_len * size_per_head,
&beta,
d_C, CType, seq_len, seq_len * seq_len,
batch_size * head_num,
computeType,
static_cast<cublasGemmAlgo_t>(algo));
}
else
{
status = cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
size_per_head, seq_len, seq_len,
&alpha,
d_B, BType, size_per_head, seq_len * size_per_head,
d_A, AType, seq_len, seq_len * seq_len,
&beta,
d_C, CType, size_per_head, seq_len * size_per_head,
batch_size * head_num,
computeType,
static_cast<cublasGemmAlgo_t>(algo));
}
}
cudaDeviceSynchronize();
gettimeofday(&end, NULL);
if(status == CUBLAS_STATUS_SUCCESS)
{
printf("algo_%d costs %.3fms \n", algo, diffTime(start, end) / ites);
if(diffTime(start, end) / ites < exec_time)
{
exec_time = diffTime(start, end) / ites;
fast_algo = algo;
}
}
}
printf("fast_algo %d costs %.3f ms\n", fast_algo, exec_time);
fprintf(fd, "%d\n", fast_algo);
}
}