fix softmax max_value
This commit is contained in:
parent
488cc11967
commit
fd852b56a0
|
@ -204,7 +204,7 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
|
|||
|
||||
mask_val = (1.0f - mask_val) * -10000.0f;
|
||||
|
||||
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val): -1e-20f;
|
||||
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val): -1e20f;
|
||||
|
||||
float max_val = blockReduceMax<float>(tmp);
|
||||
|
||||
|
@ -248,7 +248,7 @@ void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, con
|
|||
|
||||
mask_val = (1.0f - mask_val) * -10000.0f;
|
||||
|
||||
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
|
||||
float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e20f;
|
||||
float max_val = blockReduceMax<float>(tmp);
|
||||
if(threadIdx.x == 0)
|
||||
s_max = max_val;
|
||||
|
|
Loading…
Reference in a new issue