fix softmax max_value

This commit is contained in:
xjia 2019-09-11 07:09:53 +00:00
parent 488cc11967
commit fd852b56a0

View file

@ -204,7 +204,7 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
mask_val = (1.0f - mask_val) * -10000.0f;
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val): -1e-20f;
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val): -1e20f;
float max_val = blockReduceMax<float>(tmp);
@ -248,7 +248,7 @@ void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, con
mask_val = (1.0f - mask_val) * -10000.0f;
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e20f;
float max_val = blockReduceMax<float>(tmp);
if(threadIdx.x == 0)
s_max = max_val;