mirror of
https://github.com/matrix-construct/construct
synced 2024-12-02 03:32:52 +01:00
ircd::gpt::pipe: Consolidate loop; improve access pattern; various reorg.
This commit is contained in:
parent
1f49f71530
commit
536594b487
1 changed files with 59 additions and 42 deletions
101
ircd/gpt_cl.cl
101
ircd/gpt_cl.cl
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
|
__attribute__((always_inline))
|
||||||
ircd_gpt_norm_fmad(__local float4 *const out,
|
ircd_gpt_norm_fmad(__local float4 *const out,
|
||||||
__local const float4 *const in,
|
__local const float4 *const in,
|
||||||
__global const float4 *const restrict bias,
|
__global const float4 *const restrict bias,
|
||||||
|
@ -19,40 +20,6 @@ ircd_gpt_norm_fmad(__local float4 *const out,
|
||||||
out[i] = in[i] * weight[i] + bias[i];
|
out[i] = in[i] * weight[i] + bias[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix * Vector Multiply/Accumulate
|
|
||||||
inline void
|
|
||||||
ircd_gpt_sgemv(__local float4 *const restrict out,
|
|
||||||
__local const float4 *const restrict in,
|
|
||||||
__global const float4 *const restrict bias,
|
|
||||||
__global const float4 *const restrict weight,
|
|
||||||
const uint width,
|
|
||||||
const uint height,
|
|
||||||
const uint i)
|
|
||||||
{
|
|
||||||
const uint
|
|
||||||
lanes = 4;
|
|
||||||
|
|
||||||
float4
|
|
||||||
acc = bias[i];
|
|
||||||
|
|
||||||
for(uint j = 0; j < height; ++j)
|
|
||||||
{
|
|
||||||
const uint
|
|
||||||
tile = j * lanes;
|
|
||||||
|
|
||||||
for(uint k = 0; k < lanes; ++k)
|
|
||||||
{
|
|
||||||
const uint
|
|
||||||
row = tile + k,
|
|
||||||
cell = row * width + i;
|
|
||||||
|
|
||||||
acc += in[j][k] * weight[cell];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out[i] = acc;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Gaussian Error Linear Unit
|
/// Gaussian Error Linear Unit
|
||||||
inline void
|
inline void
|
||||||
ircd_gpt_ffnn_gelu(__local float4 *const out,
|
ircd_gpt_ffnn_gelu(__local float4 *const out,
|
||||||
|
@ -78,6 +45,50 @@ ircd_gpt_ffnn_gelu(__local float4 *const out,
|
||||||
out[i] = a;
|
out[i] = a;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Matrix * Vector Multiply/Accumulate
|
||||||
|
inline void
|
||||||
|
ircd_gpt_sgemv(__local float4 *const restrict out,
|
||||||
|
__local const float4 *const restrict in,
|
||||||
|
__global const float4 *const restrict bias,
|
||||||
|
__global const float4 *const restrict weight,
|
||||||
|
const uint width,
|
||||||
|
const uint height,
|
||||||
|
const uint segs)
|
||||||
|
{
|
||||||
|
const uint
|
||||||
|
li = get_local_id(0),
|
||||||
|
ln = get_local_size(0),
|
||||||
|
lanes = 4;
|
||||||
|
|
||||||
|
__attribute__((opencl_unroll_hint))
|
||||||
|
for(uint i = 0; i < segs; ++i)
|
||||||
|
{
|
||||||
|
const uint
|
||||||
|
col = i * ln + li;
|
||||||
|
|
||||||
|
out[col] = bias[col];
|
||||||
|
}
|
||||||
|
|
||||||
|
for(uint j = 0; j < height; ++j)
|
||||||
|
for(uint i = 0; i < segs; ++i)
|
||||||
|
{
|
||||||
|
const uint
|
||||||
|
col = i * ln + li;
|
||||||
|
|
||||||
|
float4 acc = 0.0f;
|
||||||
|
for(uint k = 0; k < lanes; ++k)
|
||||||
|
{
|
||||||
|
const uint
|
||||||
|
row = j * lanes + k,
|
||||||
|
cell = row * width + col;
|
||||||
|
|
||||||
|
acc += in[j][k] * weight[cell];
|
||||||
|
}
|
||||||
|
|
||||||
|
out[col] += acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
__attribute__((always_inline))
|
__attribute__((always_inline))
|
||||||
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
||||||
|
@ -94,8 +105,7 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
||||||
height = opts->ffnn_height,
|
height = opts->ffnn_height,
|
||||||
tiles = opts->ffnn_mult;
|
tiles = opts->ffnn_mult;
|
||||||
|
|
||||||
for(uint i = 0; i < tiles; ++i)
|
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, tiles);
|
||||||
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, i * ln + li);
|
|
||||||
|
|
||||||
for(uint i = 0; i < tiles; ++i)
|
for(uint i = 0; i < tiles; ++i)
|
||||||
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
|
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
|
||||||
|
@ -107,6 +117,8 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
||||||
__constant const struct ircd_gpt_opts *const opts,
|
__constant const struct ircd_gpt_opts *const opts,
|
||||||
__local union ircd_gpt_tokenv *const restrict token,
|
__local union ircd_gpt_tokenv *const restrict token,
|
||||||
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
|
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
|
||||||
|
__local union ircd_gpt_ffnn_aperaturev *const restrict tmp0,
|
||||||
|
__local union ircd_gpt_tokenv *const restrict tmp1,
|
||||||
__global const float4 *const restrict norm_bias,
|
__global const float4 *const restrict norm_bias,
|
||||||
__global const float4 *const restrict norm_weight,
|
__global const float4 *const restrict norm_weight,
|
||||||
__global const float4 *const restrict fcon_bias,
|
__global const float4 *const restrict fcon_bias,
|
||||||
|
@ -138,7 +150,7 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
|
|
||||||
// Projection
|
// Projection
|
||||||
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, li);
|
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
|
@ -197,12 +209,14 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
||||||
self[i][li] = exp(self[i][li] - mu);
|
self[i][li] = exp(self[i][li] - mu);
|
||||||
|
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
|
__attribute__((opencl_unroll_hint))
|
||||||
for(uint i = 0; i < wn; ++i)
|
for(uint i = 0; i < wn; ++i)
|
||||||
sum += self[i][li];
|
sum += self[i][li];
|
||||||
|
|
||||||
const float
|
const float
|
||||||
lambda = 1.0f / sum;
|
lambda = 1.0f / sum;
|
||||||
|
|
||||||
|
__attribute__((opencl_unroll_hint))
|
||||||
for(uint i = 0; i < wn; ++i)
|
for(uint i = 0; i < wn; ++i)
|
||||||
self[i][li] *= lambda;
|
self[i][li] *= lambda;
|
||||||
}
|
}
|
||||||
|
@ -211,6 +225,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
|
|
||||||
float4 acc = 0.0f;
|
float4 acc = 0.0f;
|
||||||
|
__attribute__((opencl_unroll_hint))
|
||||||
for(uint i = 0; i < wn; ++i)
|
for(uint i = 0; i < wn; ++i)
|
||||||
{
|
{
|
||||||
const float4
|
const float4
|
||||||
|
@ -243,7 +258,7 @@ ircd_gpt_attn_proj(__global const struct ircd_gpt_task *const ctrl,
|
||||||
height = opts->attn_height;
|
height = opts->attn_height;
|
||||||
|
|
||||||
// Projection
|
// Projection
|
||||||
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, li);
|
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
__kernel void
|
__kernel void
|
||||||
|
@ -263,6 +278,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
||||||
{
|
{
|
||||||
const uint
|
const uint
|
||||||
li = get_local_id(0),
|
li = get_local_id(0),
|
||||||
|
ln = get_local_size(0),
|
||||||
wi = get_group_id(0);
|
wi = get_group_id(0);
|
||||||
|
|
||||||
__local union ircd_gpt_tokenv
|
__local union ircd_gpt_tokenv
|
||||||
|
@ -271,7 +287,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
||||||
__local union
|
__local union
|
||||||
{
|
{
|
||||||
union ircd_gpt_ffnn_aperaturev
|
union ircd_gpt_ffnn_aperaturev
|
||||||
ffnn_fcon;
|
ffnn_fcon[2];
|
||||||
|
|
||||||
float
|
float
|
||||||
attn_self[512][12];
|
attn_self[512][12];
|
||||||
|
@ -319,7 +335,9 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
||||||
ctrl,
|
ctrl,
|
||||||
opts,
|
opts,
|
||||||
&buf0,
|
&buf0,
|
||||||
&buf.ffnn_fcon,
|
buf.ffnn_fcon + 0,
|
||||||
|
buf.ffnn_fcon + 1,
|
||||||
|
&buf1,
|
||||||
ffnn_norm_bias,
|
ffnn_norm_bias,
|
||||||
ffnn_norm_weight,
|
ffnn_norm_weight,
|
||||||
ffnn_fcon_bias,
|
ffnn_fcon_bias,
|
||||||
|
@ -376,8 +394,7 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
|
|
||||||
// Fully connected
|
// Fully connected
|
||||||
for(uint i = 0; i < tiles; ++i)
|
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
|
||||||
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, i * ln + li);
|
|
||||||
|
|
||||||
// Export queries, keys, and values.
|
// Export queries, keys, and values.
|
||||||
for(uint i = 0; i < tiles; ++i)
|
for(uint i = 0; i < tiles; ++i)
|
||||||
|
|
Loading…
Reference in a new issue