mirror of
https://github.com/matrix-construct/construct
synced 2025-01-13 16:33:53 +01:00
ircd::simt: Simplify interface; internalize workitem functions.
This commit is contained in:
parent
075b40400a
commit
bc98835b3c
6 changed files with 45 additions and 31 deletions
|
@ -11,10 +11,12 @@
|
|||
/// Broadcast originating from the local leader (index [0]). All threads in the
|
||||
/// group participate.
|
||||
inline void
|
||||
ircd_simt_broadcast_f4lldr(__local float4 *const buf,
|
||||
const uint ln,
|
||||
const uint li)
|
||||
ircd_simt_broadcast_f4lldr(__local float4 *const buf)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
for(uint stride = 1; stride < ln; stride <<= 1)
|
||||
{
|
||||
if(li < stride)
|
||||
|
|
|
@ -10,20 +10,24 @@
|
|||
|
||||
/// Compute average of all elements in the input. The result is broadcast
|
||||
/// to all elements of the output.
|
||||
///
|
||||
/// provide:
|
||||
/// li = local thread id
|
||||
/// ln = local group size
|
||||
///
|
||||
inline void
|
||||
ircd_simt_math_mean_f4lldr(__local float4 *const restrict out,
|
||||
__local const float4 *const restrict in,
|
||||
const uint num,
|
||||
const uint i)
|
||||
__local const float4 *const restrict in)
|
||||
{
|
||||
out[i] = in[i];
|
||||
ircd_simt_reduce_add_f4lldr(out, num, i);
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
if(i == 0)
|
||||
out[i][0] = ircd_simt_reduce_add_f4(out[i]);
|
||||
out[li] = in[li];
|
||||
ircd_simt_reduce_add_f4lldr(out);
|
||||
|
||||
if(i == 0)
|
||||
out[i] = out[i][0] / (num * 4);
|
||||
if(li == 0)
|
||||
out[li] = ircd_simt_reduce_add_f4(out[li]) / (ln * 4);
|
||||
|
||||
ircd_simt_broadcast_f4lldr(out, num, i);
|
||||
ircd_simt_broadcast_f4lldr(out);
|
||||
}
|
||||
|
|
|
@ -13,21 +13,23 @@
|
|||
inline void
|
||||
ircd_simt_math_norm_f4lldr(__local float4 *const out,
|
||||
__local const float4 *const in,
|
||||
__local float4 *const restrict tmp,
|
||||
const uint num,
|
||||
const uint i)
|
||||
__local float4 *const restrict tmp)
|
||||
{
|
||||
ircd_simt_math_mean_f4lldr(tmp, in, num, i);
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
ircd_simt_math_mean_f4lldr(tmp, in);
|
||||
|
||||
const float4
|
||||
sub_mean = in[i] - tmp[i];
|
||||
sub_mean = in[li] - tmp[li];
|
||||
|
||||
tmp[i] = pow(sub_mean, 2);
|
||||
ircd_simt_math_mean_f4lldr(out, tmp, num, i);
|
||||
tmp[li] = pow(sub_mean, 2);
|
||||
ircd_simt_math_mean_f4lldr(out, tmp);
|
||||
|
||||
const float4
|
||||
epsilon = 0.00001f,
|
||||
s = sqrt(out[i] + epsilon);
|
||||
s = sqrt(out[li] + epsilon);
|
||||
|
||||
out[i] = sub_mean / s;
|
||||
out[li] = sub_mean / s;
|
||||
}
|
||||
|
|
|
@ -11,10 +11,12 @@
|
|||
/// Sum all elements in the buffer. All threads in the group participate;
|
||||
/// result is placed in index [0], the rest of the buffer is trashed.
|
||||
inline void
|
||||
ircd_simt_reduce_add_f4lldr(__local float4 *const buf,
|
||||
const uint ln,
|
||||
const uint li)
|
||||
ircd_simt_reduce_add_f4lldr(__local float4 *const buf)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
for(uint stride = ln >> 1; stride > 0; stride >>= 1)
|
||||
{
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
|
|
@ -12,10 +12,12 @@
|
|||
/// the greatest value is placed in index [0], the rest of the buffer is
|
||||
/// trashed.
|
||||
inline void
|
||||
ircd_simt_reduce_max_flldr(__local float *const buf,
|
||||
const uint ln,
|
||||
const uint li)
|
||||
ircd_simt_reduce_max_flldr(__local float *const buf)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
for(uint stride = ln >> 1; stride > 0; stride >>= 1)
|
||||
{
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
|
|
@ -11,10 +11,12 @@
|
|||
/// Sort indices in `idx` which point to values contained in `val`.
|
||||
inline void
|
||||
ircd_simt_sort_idx16_flldr(__local ushort *const idx,
|
||||
__global const float *const val,
|
||||
const uint ln,
|
||||
const uint li)
|
||||
__global const float *const val)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
for(uint stride = ln >> 1; stride > 0; stride >>= 1)
|
||||
{
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
|
Loading…
Reference in a new issue