refine softmax

xjia 2019-08-26 09:08:59 +00:00
parent b6fb9aa463
commit 488cc11967


@@ -61,6 +61,39 @@ T blockReduceSum(T val)
   return val;
 }
+
+template <typename T>
+__inline__ __device__
+T warpReduceMax(T val)
+{
+  for(int mask = 16; mask > 0; mask >>= 1)
+    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  return val;
+}
+
+/* Calculate the maximum of all elements in a block */
+template <typename T>
+__inline__ __device__
+T blockReduceMax(T val)
+{
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f; // in-warp idx
+  int wid = threadIdx.x >> 5;    // warp idx
+
+  val = warpReduceMax(val);      // get the max within each warp
+
+  if(lane == 0)                  // record each warp's max, indexed by warp id
+    shared[wid] = val;
+
+  __syncthreads();
+
+  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : 0;
+  val = warpReduceMax(val);
+
+  return val;
+}
+
 __inline__ __device__
 int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
 {
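The two helpers added above compute a block-wide maximum in two passes of an XOR-butterfly shuffle: each warp reduces its 32 lanes, lane 0 of every warp parks its partial result in shared memory, and the first warp then reduces those partials. A minimal standalone sketch of the same pattern follows; it is not part of this commit, and the names FULL_MASK, warpMax, and reduce_max_demo are illustrative, not from the repository.

// Standalone sketch (not from this commit): block-wide max via warp shuffles.
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff   // illustrative; the commit uses FINAL_MASK

__inline__ __device__ float warpMax(float val)
{
  // XOR butterfly: after log2(32) = 5 exchanges every lane holds the warp max.
  for(int mask = 16; mask > 0; mask >>= 1)
    val = fmaxf(val, __shfl_xor_sync(FULL_MASK, val, mask, 32));
  return val;
}

__global__ void reduce_max_demo(const float* in, float* out, int n)
{
  __shared__ float shared[32];            // one slot per warp (<= 1024 threads)
  float v = threadIdx.x < n ? in[threadIdx.x] : -1e20f;
  int lane = threadIdx.x & 0x1f;
  int wid  = threadIdx.x >> 5;
  v = warpMax(v);                         // pass 1: reduce within each warp
  if(lane == 0) shared[wid] = v;          // lane 0 publishes the warp partial
  __syncthreads();
  v = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e20f;
  v = warpMax(v);                         // pass 2: reduce the warp partials
  if(threadIdx.x == 0) *out = v;
}

int main()
{
  const int n = 100;
  float h[n], result, *d_in, *d_out;
  for(int i = 0; i < n; ++i) h[i] = (float)(i % 37);
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h, n * sizeof(float), cudaMemcpyHostToDevice);
  reduce_max_demo<<<1, 128>>>(d_in, d_out, n);   // 4 warps
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("block max = %f (expect 36)\n", result);
  cudaFree(d_in); cudaFree(d_out);
  return 0;
}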
@@ -162,16 +195,25 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
   int qk_offset = blockIdx.x * seq_len * seq_len;
   int mask_offset = batch_id * seq_len * seq_len;
 
-  __shared__ float s_sum;
+  __shared__ float s_sum, s_max;
 
   for(int i = 0; i < seq_len; ++i)
   {
-    T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
-    T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
+    float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
+    float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
 
-    mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
+    mask_val = (1.0f - mask_val) * -10000.0f;
 
+    float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
+    float max_val = blockReduceMax<float>(tmp);
+    if(threadIdx.x == 0)
+      s_max = max_val;
+    __syncthreads();
 
-    qk = __expf((float)(qk * scaler + mask_val));
+    qk = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f;
 
     float sum_val = blockReduceSum<float>(qk);
 
     if(threadIdx.x == 0)
@@ -181,7 +223,7 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
     __syncthreads();
 
     if(threadIdx.x < seq_len)
-      qk_buf_[threadIdx.x + qk_offset] = qk / (T)s_sum;
+      qk_buf_[threadIdx.x + qk_offset] = (T)(qk / s_sum);
 
     qk_offset += seq_len;
     mask_offset += seq_len;
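The change in this kernel is the numerically stable softmax: the per-row maximum is subtracted inside __expf, which cannot change the result, since exp(x_i - m) / sum_j exp(x_j - m) equals exp(x_i) / sum_j exp(x_j), but it pins the largest exponent at exactly 0 so fp32 never overflows for large logits. A minimal single-row sketch of that idea follows; the kernel name and the serial/atomic reductions are illustrative simplifications, whereas the commit uses the block reductions shown above.

#include <cstdio>
#include <cuda_runtime.h>

// One block normalizes one row of length n (n <= blockDim.x). Reductions are
// done serially / with a shared atomic for brevity, not for performance.
__global__ void stable_softmax_row(float* row, int n)
{
  __shared__ float s_max, s_sum;
  float x = threadIdx.x < n ? row[threadIdx.x] : -1e20f;

  if(threadIdx.x == 0)
  {
    float m = -1e20f;
    for(int i = 0; i < n; ++i) m = fmaxf(m, row[i]);   // row max
    s_max = m;
    s_sum = 0.0f;
  }
  __syncthreads();

  // Largest argument to __expf is 0, so no overflow even for huge logits.
  float e = threadIdx.x < n ? __expf(x - s_max) : 0.0f;
  atomicAdd(&s_sum, e);
  __syncthreads();

  if(threadIdx.x < n)
    row[threadIdx.x] = e / s_sum;
}

int main()
{
  const int n = 8;
  float h[n] = {1000.f, 1001.f, 999.f, 1000.f, 998.f, 1002.f, 1000.f, 997.f};
  float* d;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemcpy(d, h, n * sizeof(float), cudaMemcpyHostToDevice);
  stable_softmax_row<<<1, 32>>>(d, n);   // naive __expf(1000.f) would be inf
  cudaMemcpy(h, d, n * sizeof(float), cudaMemcpyDeviceToHost);
  for(int i = 0; i < n; ++i) printf("%f ", h[i]);
  printf("\n");
  cudaFree(d);
  return 0;
}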
@@ -192,21 +234,27 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
 template <typename T>
 __global__
 void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num,
-  const int seq_len, const T scaler)
+  const int seq_len, const float scaler)
 {
   int batch_id = blockIdx.x / head_num / seq_len;
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
 
-  __shared__ float s_sum;
+  __shared__ float s_sum, s_max;
 
-  T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
-  T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
+  float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
+  float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
 
-  mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
+  mask_val = (1.0f - mask_val) * -10000.0f;
 
-  float qk_tmp = threadIdx.x < seq_len ? __expf((float)(qk * scaler + mask_val)) : 0.0f;
+  float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
+  float max_val = blockReduceMax<float>(tmp);
+  if(threadIdx.x == 0)
+    s_max = max_val;
+  __syncthreads();
+
+  float qk_tmp = threadIdx.x < seq_len ? __expf((float)(tmp - s_max)) : 0.0f;
 
   float sum_val = blockReduceSum<float>(qk_tmp);
 
   if(threadIdx.x == 0)
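For orientation: the offset arithmetic shows that softmax_kernel walks all seq_len rows of one (batch, head) slice per block (qk_offset = blockIdx.x * seq_len * seq_len), while softmax_kernel_v2 handles exactly one attention row per block (qk_offset = blockIdx.x * seq_len). Below is a hedged launch sketch for the v2 kernel consistent with that indexing; the wrapper name and the round-up-to-a-warp block size are assumptions, not taken from the commit.

// Hedged launch sketch inferred from the blockIdx/threadIdx arithmetic above;
// launch_softmax_v2 and its rounding policy are illustrative, not from the repo.
template <typename T>
void launch_softmax_v2(T* qk_buf_, const T* attr_mask,
                       int batch_size, int head_num, int seq_len,
                       float scaler, cudaStream_t stream)
{
  dim3 grid(batch_size * head_num * seq_len);   // one block per attention row
  dim3 block((seq_len + 31) / 32 * 32);         // round threads up to a warp
  softmax_kernel_v2<T><<<grid, block, 0, stream>>>(
      qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler);
}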