0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-11-29 18:22:50 +01:00

ircd::gpt::pipe: Consolidate loop; improve access pattern; various reorg.

This commit is contained in:
Jason Volk 2021-04-26 20:43:21 -07:00
parent 1f49f71530
commit 536594b487

View file

@ -10,6 +10,7 @@
inline void
__attribute__((always_inline))
ircd_gpt_norm_fmad(__local float4 *const out,
__local const float4 *const in,
__global const float4 *const restrict bias,
@ -19,40 +20,6 @@ ircd_gpt_norm_fmad(__local float4 *const out,
out[i] = in[i] * weight[i] + bias[i];
}
// Matrix * Vector Multiply/Accumulate
inline void
ircd_gpt_sgemv(__local float4 *const restrict out,
__local const float4 *const restrict in,
__global const float4 *const restrict bias,
__global const float4 *const restrict weight,
const uint width,
const uint height,
const uint i)
{
const uint
lanes = 4;
float4
acc = bias[i];
for(uint j = 0; j < height; ++j)
{
const uint
tile = j * lanes;
for(uint k = 0; k < lanes; ++k)
{
const uint
row = tile + k,
cell = row * width + i;
acc += in[j][k] * weight[cell];
}
}
out[i] = acc;
}
/// Gaussian Error Linear Unit
inline void
ircd_gpt_ffnn_gelu(__local float4 *const out,
@ -78,6 +45,50 @@ ircd_gpt_ffnn_gelu(__local float4 *const out,
out[i] = a;
}
// Matrix * Vector Multiply/Accumulate
inline void
ircd_gpt_sgemv(__local float4 *const restrict out,
__local const float4 *const restrict in,
__global const float4 *const restrict bias,
__global const float4 *const restrict weight,
const uint width,
const uint height,
const uint segs)
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
lanes = 4;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < segs; ++i)
{
const uint
col = i * ln + li;
out[col] = bias[col];
}
for(uint j = 0; j < height; ++j)
for(uint i = 0; i < segs; ++i)
{
const uint
col = i * ln + li;
float4 acc = 0.0f;
for(uint k = 0; k < lanes; ++k)
{
const uint
row = j * lanes + k,
cell = row * width + col;
acc += in[j][k] * weight[cell];
}
out[col] += acc;
}
}
inline void
__attribute__((always_inline))
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
@ -94,8 +105,7 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
height = opts->ffnn_height,
tiles = opts->ffnn_mult;
for(uint i = 0; i < tiles; ++i)
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, i * ln + li);
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, tiles);
for(uint i = 0; i < tiles; ++i)
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
@ -107,6 +117,8 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local union ircd_gpt_tokenv *const restrict token,
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
__local union ircd_gpt_ffnn_aperaturev *const restrict tmp0,
__local union ircd_gpt_tokenv *const restrict tmp1,
__global const float4 *const restrict norm_bias,
__global const float4 *const restrict norm_weight,
__global const float4 *const restrict fcon_bias,
@ -138,7 +150,7 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
barrier(CLK_LOCAL_MEM_FENCE);
// Projection
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, li);
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
}
inline void
@ -197,12 +209,14 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
self[i][li] = exp(self[i][li] - mu);
float sum = 0.0f;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
sum += self[i][li];
const float
lambda = 1.0f / sum;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
self[i][li] *= lambda;
}
@ -211,6 +225,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
barrier(CLK_LOCAL_MEM_FENCE);
float4 acc = 0.0f;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
{
const float4
@ -243,7 +258,7 @@ ircd_gpt_attn_proj(__global const struct ircd_gpt_task *const ctrl,
height = opts->attn_height;
// Projection
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, li);
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, 1);
}
__kernel void
@ -263,6 +278,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
wi = get_group_id(0);
__local union ircd_gpt_tokenv
@ -271,7 +287,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
__local union
{
union ircd_gpt_ffnn_aperaturev
ffnn_fcon;
ffnn_fcon[2];
float
attn_self[512][12];
@ -319,7 +335,9 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
ctrl,
opts,
&buf0,
&buf.ffnn_fcon,
buf.ffnn_fcon + 0,
buf.ffnn_fcon + 1,
&buf1,
ffnn_norm_bias,
ffnn_norm_weight,
ffnn_fcon_bias,
@ -376,8 +394,7 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_task *const ctrl,
barrier(CLK_LOCAL_MEM_FENCE);
// Fully connected
for(uint i = 0; i < tiles; ++i)
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, i * ln + li);
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
// Export queries, keys, and values.
for(uint i = 0; i < tiles; ++i)