0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2025-01-14 16:46:50 +01:00

ircd::gpt::pipe: Expansions to the context size.

This commit is contained in:
Jason Volk 2021-04-22 12:17:29 -07:00
parent f61239a52c
commit 5392ba0c7d
2 changed files with 40 additions and 24 deletions

View file

@ -10,8 +10,8 @@
inline void
ircd_gpt_norm_fmad(__local float4 *const restrict out,
__local const float4 *const restrict in,
ircd_gpt_norm_fmad(__local float4 *const out,
__local const float4 *const in,
__global const float4 *const restrict bias,
__global const float4 *const restrict weight,
const uint i)
@ -36,8 +36,19 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
acc = bias[i];
for(uint j = 0; j < height; ++j)
{
const uint
tile = j * lanes;
for(uint k = 0; k < lanes; ++k)
acc += in[j][k] * weight[width * (j * lanes + k) + i];
{
const uint
row = tile + k,
cell = row * width + i;
acc += in[j][k] * weight[cell];
}
}
out[i] = acc;
}
@ -95,7 +106,6 @@ __attribute__((always_inline))
ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local union ircd_gpt_tokenv *const restrict token,
__local union ircd_gpt_tokenv *const restrict tmp,
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
__global const float4 *const restrict norm_bias,
__global const float4 *const restrict norm_weight,
@ -116,13 +126,13 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
// Layer re-normalization
ircd_simt_math_norm_f4lldr(token->word, token->word, buf->word);
ircd_gpt_norm_fmad(tmp->word, token->word, norm_bias, norm_weight, li);
ircd_gpt_norm_fmad(token->word, token->word, norm_bias, norm_weight, li);
// ln's writes are still pending but fcon reads results across threads.
barrier(CLK_LOCAL_MEM_FENCE);
// Fully connected
ircd_gpt_ffnn_fcon(ctrl, opts, buf, tmp, fcon_bias, fcon_weight);
ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight);
// fcon's writes are still pending but proj reads results across threads.
barrier(CLK_LOCAL_MEM_FENCE);
@ -255,14 +265,18 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
li = get_local_id(0),
wi = get_group_id(0);
__local float
self[96][12];
__local union ircd_gpt_tokenv
buf1, buf0;
__local union ircd_gpt_ffnn_aperaturev
ffnn_fcon;
__local union
{
union ircd_gpt_ffnn_aperaturev
ffnn_fcon;
float
attn_self[512][12];
}
buf;
// Self-attention backend; this computes the self-attention result now
// that keys and values are globally visible across tokens.
@ -271,7 +285,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
ctrl,
opts,
&buf1,
self,
buf.attn_self,
state,
mask
);
@ -305,8 +319,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
ctrl,
opts,
&buf0,
&buf1,
&ffnn_fcon,
&buf.ffnn_fcon,
ffnn_norm_bias,
ffnn_norm_weight,
ffnn_fcon_bias,
@ -404,10 +417,13 @@ ircd_gpt_lm_embed(__global const struct ircd_gpt_task *const ctrl,
__global const union ircd_gpt_tokenv *const restrict vocab)
{
const uint
li = get_local_id(0);
li = get_local_id(0),
wi = get_group_id(0),
wn = get_num_groups(0);
for(uint i = 0; i < ctrl->tokens; ++i)
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
if(i % wn == wi)
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
}
__kernel void
@ -429,9 +445,9 @@ ircd_gpt_lm_norm(__global const struct ircd_gpt_task *const ctrl,
// Final re-normalization
ircd_simt_math_norm_f4lldr(token.word, token.word, tmp.word);
ircd_gpt_norm_fmad(tmp.word, token.word, norm_bias, norm_weight, li);
ircd_gpt_norm_fmad(token.word, token.word, norm_bias, norm_weight, li);
accum[wi].word[li] = tmp.word[li];
accum[wi].word[li] = token.word[li];
}
__kernel void

View file

@ -277,8 +277,8 @@ ircd::gpt::pipe::exec::exec(task &task,
}
,range_lm_embed
{
{ 1 * 192UL, 0, },
{ 192UL, 0, },
{ std::min(tokens, 12UL) * 192UL, 0, },
{ 192UL, 0, },
}
,range_negative
{
@ -303,8 +303,8 @@ ircd::gpt::pipe::exec::exec(task &task,
}
,range_lm_logsm
{
{ 1 * 256UL, 0 },
{ 256UL, 0 },
{ 1 * 256UL, 0 },
{ 256UL, 0 },
}
,range_lm_select
{
@ -459,12 +459,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
}
,state
{
96 * 3 * 768 * sizeof(float),
512 * 3 * 768 * sizeof(float),
mutable_buffer{}
}
,accum
{
96 * 768 * sizeof(float),
512 * 768 * sizeof(float),
mutable_buffer{}
}
,logit