refine softmax
This commit is contained in:
parent
b6fb9aa463
commit
488cc11967
|
@ -61,6 +61,39 @@ T blockReduceSum(T val)
|
|||
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the maximum of all elements in a warp */
template <typename T>
__inline__ __device__
T warpReduceMax(T val)
{
  // XOR-shuffle exchange over lane distances 16, 8, 4, 2, 1 (width 32):
  // at every step a lane combines its running maximum with its partner's,
  // so after the last step each participating lane holds the same maximum.
  // FINAL_MASK is defined elsewhere in the file (presumably the full-warp
  // mask 0xffffffff -- TODO confirm), so all 32 lanes are expected to call this.
  int mask = 16;
  while(mask > 0)
  {
    const T partner = __shfl_xor_sync(FINAL_MASK, val, mask, 32);
    val = max(val, partner);
    mask >>= 1;
  }
  return val;
}
|
||||
|
||||
/* Calculate the maximum of all elements in a block.
 *
 * val     : this thread's contribution.
 * returns : on the threads of warp 0, the maximum over all full warps of the
 *           block. Threads of other warps only get their own warp's maximum,
 *           so callers should read the result from thread 0 and broadcast it.
 *
 * Must be called by every thread of the block (it contains __syncthreads()).
 * The scratch array has 32 slots, i.e. one per warp for up to 32 warps.
 * The warp count is taken as blockDim.x >> 5, so a trailing partial warp is
 * not included in the result.
 * NOTE(review): the scratch array is static, so two back-to-back calls need a
 * __syncthreads() in between to avoid overwriting slots still being read.
 */
template <typename T>
__inline__ __device__
T blockReduceMax(T val)
{
  static __shared__ T shared[32];        // one partial maximum per warp
  const int lane = threadIdx.x & 0x1f;   // in-warp idx
  const int wid = threadIdx.x >> 5;      // warp idx

  val = warpReduceMax(val);              // get max in each warp

  if(lane == 0)                          // record in-warp max by warp idx
    shared[wid] = val;

  __syncthreads();

  // The first (blockDim.x >> 5) threads each pick up one warp's partial.
  // Every other thread keeps its own warp-reduced value instead of being
  // padded with 0: 0 is not the identity of max and would wrongly win when
  // all inputs are negative. A thread's own warp maximum is always a real
  // input element, so it can never exceed the true block maximum.
  if(threadIdx.x < (blockDim.x >> 5))
    val = shared[lane];
  val = warpReduceMax(val);

  return val;
}
|
||||
|
||||
__inline__ __device__
|
||||
int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
|
||||
{
|
||||
|
@ -162,16 +195,25 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
|
|||
int qk_offset = blockIdx.x * seq_len * seq_len;
|
||||
int mask_offset = batch_id * seq_len * seq_len;
|
||||
|
||||
__shared__ float s_sum;
|
||||
__shared__ float s_sum, s_max;
|
||||
|
||||
for(int i = 0; i < seq_len; ++i)
|
||||
{
|
||||
T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
|
||||
T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
|
||||
float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
|
||||
float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
|
||||
|
||||
mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
|
||||
mask_val = (1.0f - mask_val) * -10000.0f;
|
||||
|
||||
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val): -1e-20f;
|
||||
|
||||
float max_val = blockReduceMax<float>(tmp);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
s_max = max_val;
|
||||
__syncthreads();
|
||||
|
||||
qk = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f;
|
||||
|
||||
qk = __expf((float)(qk * scaler + mask_val));
|
||||
float sum_val = blockReduceSum<float>(qk);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
|
@ -181,7 +223,7 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
|
|||
__syncthreads();
|
||||
|
||||
if(threadIdx.x < seq_len)
|
||||
qk_buf_[threadIdx.x + qk_offset] = qk / (T)s_sum;
|
||||
qk_buf_[threadIdx.x + qk_offset] = (T)(qk / s_sum);
|
||||
|
||||
qk_offset += seq_len;
|
||||
mask_offset += seq_len;
|
||||
|
@ -192,21 +234,27 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
|
|||
template <typename T>
|
||||
__global__
|
||||
void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num,
|
||||
const int seq_len, const T scaler)
|
||||
const int seq_len, const float scaler)
|
||||
{
|
||||
int batch_id = blockIdx.x / head_num / seq_len;
|
||||
int seq_id = blockIdx.x % seq_len;
|
||||
int qk_offset = blockIdx.x * seq_len;
|
||||
int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
|
||||
|
||||
__shared__ float s_sum;
|
||||
__shared__ float s_sum, s_max;
|
||||
|
||||
T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
|
||||
T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
|
||||
float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
|
||||
float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
|
||||
|
||||
mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
|
||||
mask_val = (1.0f - mask_val) * -10000.0f;
|
||||
|
||||
float qk_tmp = threadIdx.x < seq_len ? __expf((float)(qk * scaler + mask_val)) : 0.0f;
|
||||
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
|
||||
float max_val = blockReduceMax<float>(tmp);
|
||||
if(threadIdx.x == 0)
|
||||
s_max = max_val;
|
||||
__syncthreads();
|
||||
|
||||
float qk_tmp = threadIdx.x < seq_len ? __expf((float)(tmp - s_max)) : 0.0f;
|
||||
float sum_val = blockReduceSum<float>(qk_tmp);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
|
|
Loading…
Reference in a new issue