From bc98835b3cc7b1ddd5b14f133ba74df5e5549014 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 15 Apr 2021 10:07:17 -0700 Subject: [PATCH] ircd::simt: Simplify interface; internalize workitem functions. --- include/ircd/simt/broadcast.h | 8 +++++--- include/ircd/simt/mean.h | 24 ++++++++++++++---------- include/ircd/simt/norm.h | 20 +++++++++++--------- include/ircd/simt/reduce_add.h | 8 +++++--- include/ircd/simt/reduce_max.h | 8 +++++--- include/ircd/simt/sort.h | 8 +++++--- 6 files changed, 45 insertions(+), 31 deletions(-) diff --git a/include/ircd/simt/broadcast.h b/include/ircd/simt/broadcast.h index b5e6a79cb..fb7aab718 100644 --- a/include/ircd/simt/broadcast.h +++ b/include/ircd/simt/broadcast.h @@ -11,10 +11,12 @@ /// Broadcast originating from the local leader (index [0]). All threads in the /// group participate. inline void -ircd_simt_broadcast_f4lldr(__local float4 *const buf, - const uint ln, - const uint li) +ircd_simt_broadcast_f4lldr(__local float4 *const buf) { + const uint + li = get_local_id(0), + ln = get_local_size(0); + for(uint stride = 1; stride < ln; stride <<= 1) { if(li < stride) diff --git a/include/ircd/simt/mean.h b/include/ircd/simt/mean.h index 4a27cf822..2ec178cbf 100644 --- a/include/ircd/simt/mean.h +++ b/include/ircd/simt/mean.h @@ -10,20 +10,24 @@ /// Compute average of all elements in the input. The result is broadcast /// to all elements of the output. 
+/// +/// computed internally (callers no longer pass these): +/// li = local thread id +/// ln = local group size +/// inline void ircd_simt_math_mean_f4lldr(__local float4 *const restrict out, - __local const float4 *const restrict in, - const uint num, - const uint i) + __local const float4 *const restrict in) { - out[i] = in[i]; - ircd_simt_reduce_add_f4lldr(out, num, i); + const uint + li = get_local_id(0), + ln = get_local_size(0); - if(i == 0) - out[i][0] = ircd_simt_reduce_add_f4(out[i]); + out[li] = in[li]; + ircd_simt_reduce_add_f4lldr(out); - if(i == 0) - out[i] = out[i][0] / (num * 4); + if(li == 0) + out[li] = ircd_simt_reduce_add_f4(out[li]) / (ln * 4); - ircd_simt_broadcast_f4lldr(out, num, i); + ircd_simt_broadcast_f4lldr(out); } diff --git a/include/ircd/simt/norm.h b/include/ircd/simt/norm.h index ce303dad7..256ec150e 100644 --- a/include/ircd/simt/norm.h +++ b/include/ircd/simt/norm.h @@ -13,21 +13,23 @@ inline void ircd_simt_math_norm_f4lldr(__local float4 *const out, __local const float4 *const in, - __local float4 *const restrict tmp, - const uint num, - const uint i) + __local float4 *const restrict tmp) { - ircd_simt_math_mean_f4lldr(tmp, in, num, i); + const uint + li = get_local_id(0), + ln = get_local_size(0); + + ircd_simt_math_mean_f4lldr(tmp, in); const float4 - sub_mean = in[i] - tmp[i]; + sub_mean = in[li] - tmp[li]; - tmp[i] = pow(sub_mean, 2); - ircd_simt_math_mean_f4lldr(out, tmp, num, i); + tmp[li] = pow(sub_mean, 2); + ircd_simt_math_mean_f4lldr(out, tmp); const float4 epsilon = 0.00001f, - s = sqrt(out[i] + epsilon); + s = sqrt(out[li] + epsilon); - out[i] = sub_mean / s; + out[li] = sub_mean / s; } diff --git a/include/ircd/simt/reduce_add.h b/include/ircd/simt/reduce_add.h index 2a13c5166..f210c1416 100644 --- a/include/ircd/simt/reduce_add.h +++ b/include/ircd/simt/reduce_add.h @@ -11,10 +11,12 @@ /// Sum all elements in the buffer. All threads in the group participate; /// result is placed in index [0], the rest of the buffer is trashed. 
inline void -ircd_simt_reduce_add_f4lldr(__local float4 *const buf, - const uint ln, - const uint li) +ircd_simt_reduce_add_f4lldr(__local float4 *const buf) { + const uint + li = get_local_id(0), + ln = get_local_size(0); + for(uint stride = ln >> 1; stride > 0; stride >>= 1) { barrier(CLK_LOCAL_MEM_FENCE); diff --git a/include/ircd/simt/reduce_max.h b/include/ircd/simt/reduce_max.h index f8ff577de..985bc0812 100644 --- a/include/ircd/simt/reduce_max.h +++ b/include/ircd/simt/reduce_max.h @@ -12,10 +12,12 @@ /// the greatest value is placed in index [0], the rest of the buffer is /// trashed. inline void -ircd_simt_reduce_max_flldr(__local float *const buf, - const uint ln, - const uint li) +ircd_simt_reduce_max_flldr(__local float *const buf) { + const uint + li = get_local_id(0), + ln = get_local_size(0); + for(uint stride = ln >> 1; stride > 0; stride >>= 1) { barrier(CLK_LOCAL_MEM_FENCE); diff --git a/include/ircd/simt/sort.h b/include/ircd/simt/sort.h index d00176ecb..7264378a8 100644 --- a/include/ircd/simt/sort.h +++ b/include/ircd/simt/sort.h @@ -11,10 +11,12 @@ /// Sort indices in `idx` which point to values contained in `val`. inline void ircd_simt_sort_idx16_flldr(__local ushort *const idx, - __global const float *const val, - const uint ln, - const uint li) + __global const float *const val) { + const uint + li = get_local_id(0), + ln = get_local_size(0); + for(uint stride = ln >> 1; stride > 0; stride >>= 1) { barrier(CLK_LOCAL_MEM_FENCE);