mirror of
https://github.com/matrix-construct/construct
synced 2024-11-29 18:22:50 +01:00
ircd::gpt::pipe: Consolidate loop; improve access pattern; various reorg.
This commit is contained in:
parent
1f49f71530
commit
536594b487
1 changed files with 59 additions and 42 deletions
101
ircd/gpt_cl.cl
101
ircd/gpt_cl.cl
|
@ -10,6 +10,7 @@
|
|||
|
||||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_norm_fmad(__local float4 *const out,
|
||||
__local const float4 *const in,
|
||||
__global const float4 *const restrict bias,
|
||||
|
@ -19,40 +20,6 @@ ircd_gpt_norm_fmad(__local float4 *const out,
|
|||
out[i] = in[i] * weight[i] + bias[i];
|
||||
}
|
||||
|
||||
// Matrix * Vector Multiply/Accumulate
|
||||
inline void
|
||||
ircd_gpt_sgemv(__local float4 *const restrict out,
|
||||
__local const float4 *const restrict in,
|
||||
__global const float4 *const restrict bias,
|
||||
__global const float4 *const restrict weight,
|
||||
const uint width,
|
||||
const uint height,
|
||||
const uint i)
|
||||
{
|
||||
const uint
|
||||
lanes = 4;
|
||||
|
||||
float4
|
||||
acc = bias[i];
|
||||
|
||||
for(uint j = 0; j < height; ++j)
|
||||
{
|
||||
const uint
|
||||
tile = j * lanes;
|
||||
|
||||
for(uint k = 0; k < lanes; ++k)
|
||||
{
|
||||
const uint
|
||||
row = tile + k,
|
||||
cell = row * width + i;
|
||||
|
||||
acc += in[j][k] * weight[cell];
|
||||
}
|
||||
}
|
||||
|
||||
out[i] = acc;
|
||||
}
|
||||
|
||||
/// Gaussian Error Linear Unit
|
||||
inline void
|
||||
ircd_gpt_ffnn_gelu(__local float4 *const out,
|
||||
|
@ -78,6 +45,50 @@ ircd_gpt_ffnn_gelu(__local float4 *const out,
|
|||
out[i] = a;
|
||||
}
|
||||
|
||||
// Matrix * Vector Multiply/Accumulate
|
||||
inline void
|
||||
ircd_gpt_sgemv(__local float4 *const restrict out,
|
||||
__local const float4 *const restrict in,
|
||||
__global const float4 *const restrict bias,
|
||||
__global const float4 *const restrict weight,
|
||||
const uint width,
|
||||
const uint height,
|
||||
const uint segs)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
lanes = 4;
|
||||
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < segs; ++i)
|
||||
{
|
||||
const uint
|
||||
col = i * ln + li;
|
||||
|
||||
out[col] = bias[col];
|
||||
}
|
||||
|
||||
for(uint j = 0; j < height; ++j)
|
||||
for(uint i = 0; i < segs; ++i)
|
||||
{
|
||||
const uint
|
||||
col = i * ln + li;
|
||||
|
||||
float4 acc = 0.0f;
|
||||
for(uint k = 0; k < lanes; ++k)
|
||||
{
|
||||
const uint
|
||||
row = j * lanes + k,
|
||||
cell = row * width + col;
|
||||
|
||||
acc += in[j][k] * weight[cell];
|
||||
}
|
||||
|
||||
out[col] += acc;
|
||||
}
|
||||
}
|
||||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
||||
|
@ -94,8 +105,7 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
|||
height = opts->ffnn_height,
|
||||
tiles = opts->ffnn_mult;
|
||||
|
||||
for(uint i = 0; i < tiles; ++i)
|
||||
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, i * ln + li);
|
||||
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, tiles);
|
||||
|
||||
for(uint i = 0; i < tiles; ++i)
|
||||
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
|
||||
|
@ -107,6 +117,8 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
|||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local union ircd_gpt_tokenv *const restrict token,
|
||||
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
|
||||
__local union ircd_gpt_ffnn_aperaturev *const restrict tmp0,
|
||||
__local union ircd_gpt_tokenv *const restrict tmp1,
|
||||
__global const float4 *const restrict norm_bias,
|
||||
__global const float4 *const restrict norm_weight,
|
||||
__global const float4 *const restrict fcon_bias,
|
||||
|
@ -138,7 +150,7 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
|||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Projection
|
||||
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, li);
|
||||
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
|
||||
}
|
||||
|
||||
inline void
|
||||
|
@ -197,12 +209,14 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
|||
self[i][li] = exp(self[i][li] - mu);
|
||||
|
||||
float sum = 0.0f;
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
sum += self[i][li];
|
||||
|
||||
const float
|
||||
lambda = 1.0f / sum;
|
||||
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
self[i][li] *= lambda;
|
||||
}
|
||||
|
@ -211,6 +225,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
|||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
float4 acc = 0.0f;
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
{
|
||||
const float4
|
||||
|
@ -243,7 +258,7 @@ ircd_gpt_attn_proj(__global const struct ircd_gpt_task *const ctrl,
|
|||
height = opts->attn_height;
|
||||
|
||||
// Projection
|
||||
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, li);
|
||||
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, 1);
|
||||
}
|
||||
|
||||
__kernel void
|
||||
|
@ -263,6 +278,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
|||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
wi = get_group_id(0);
|
||||
|
||||
__local union ircd_gpt_tokenv
|
||||
|
@ -271,7 +287,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
|||
__local union
|
||||
{
|
||||
union ircd_gpt_ffnn_aperaturev
|
||||
ffnn_fcon;
|
||||
ffnn_fcon[2];
|
||||
|
||||
float
|
||||
attn_self[512][12];
|
||||
|
@ -319,7 +335,9 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
|||
ctrl,
|
||||
opts,
|
||||
&buf0,
|
||||
&buf.ffnn_fcon,
|
||||
buf.ffnn_fcon + 0,
|
||||
buf.ffnn_fcon + 1,
|
||||
&buf1,
|
||||
ffnn_norm_bias,
|
||||
ffnn_norm_weight,
|
||||
ffnn_fcon_bias,
|
||||
|
@ -376,8 +394,7 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
|||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Fully connected
|
||||
for(uint i = 0; i < tiles; ++i)
|
||||
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, i * ln + li);
|
||||
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
|
||||
|
||||
// Export queries, keys, and values.
|
||||
for(uint i = 0; i < tiles; ++i)
|
||||
|
|
Loading…
Reference in a new issue