mirror of
https://github.com/matrix-construct/construct
synced 2024-12-26 23:44:01 +01:00
ircd::gpt::pipe: Expansions to the context size.
This commit is contained in:
parent
f61239a52c
commit
5392ba0c7d
2 changed files with 40 additions and 24 deletions
|
@ -10,8 +10,8 @@
|
|||
|
||||
|
||||
inline void
|
||||
ircd_gpt_norm_fmad(__local float4 *const restrict out,
|
||||
__local const float4 *const restrict in,
|
||||
ircd_gpt_norm_fmad(__local float4 *const out,
|
||||
__local const float4 *const in,
|
||||
__global const float4 *const restrict bias,
|
||||
__global const float4 *const restrict weight,
|
||||
const uint i)
|
||||
|
@ -36,8 +36,19 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
|
|||
acc = bias[i];
|
||||
|
||||
for(uint j = 0; j < height; ++j)
|
||||
{
|
||||
const uint
|
||||
tile = j * lanes;
|
||||
|
||||
for(uint k = 0; k < lanes; ++k)
|
||||
acc += in[j][k] * weight[width * (j * lanes + k) + i];
|
||||
{
|
||||
const uint
|
||||
row = tile + k,
|
||||
cell = row * width + i;
|
||||
|
||||
acc += in[j][k] * weight[cell];
|
||||
}
|
||||
}
|
||||
|
||||
out[i] = acc;
|
||||
}
|
||||
|
@ -95,7 +106,6 @@ __attribute__((always_inline))
|
|||
ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local union ircd_gpt_tokenv *const restrict token,
|
||||
__local union ircd_gpt_tokenv *const restrict tmp,
|
||||
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
|
||||
__global const float4 *const restrict norm_bias,
|
||||
__global const float4 *const restrict norm_weight,
|
||||
|
@ -116,13 +126,13 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
|||
|
||||
// Layer re-normalization
|
||||
ircd_simt_math_norm_f4lldr(token->word, token->word, buf->word);
|
||||
ircd_gpt_norm_fmad(tmp->word, token->word, norm_bias, norm_weight, li);
|
||||
ircd_gpt_norm_fmad(token->word, token->word, norm_bias, norm_weight, li);
|
||||
|
||||
// ln's writes are still pending but fcon reads results across threads.
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Fully connected
|
||||
ircd_gpt_ffnn_fcon(ctrl, opts, buf, tmp, fcon_bias, fcon_weight);
|
||||
ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight);
|
||||
|
||||
// fcon's writes are still pending but proj reads results across threads.
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
@ -255,14 +265,18 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
|||
li = get_local_id(0),
|
||||
wi = get_group_id(0);
|
||||
|
||||
__local float
|
||||
self[96][12];
|
||||
|
||||
__local union ircd_gpt_tokenv
|
||||
buf1, buf0;
|
||||
|
||||
__local union ircd_gpt_ffnn_aperaturev
|
||||
ffnn_fcon;
|
||||
__local union
|
||||
{
|
||||
union ircd_gpt_ffnn_aperaturev
|
||||
ffnn_fcon;
|
||||
|
||||
float
|
||||
attn_self[512][12];
|
||||
}
|
||||
buf;
|
||||
|
||||
// Self-attention backend; this computes the self-attention result now
|
||||
// that keys and values are globally visible across tokens.
|
||||
|
@ -271,7 +285,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
|||
ctrl,
|
||||
opts,
|
||||
&buf1,
|
||||
self,
|
||||
buf.attn_self,
|
||||
state,
|
||||
mask
|
||||
);
|
||||
|
@ -305,8 +319,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
|||
ctrl,
|
||||
opts,
|
||||
&buf0,
|
||||
&buf1,
|
||||
&ffnn_fcon,
|
||||
&buf.ffnn_fcon,
|
||||
ffnn_norm_bias,
|
||||
ffnn_norm_weight,
|
||||
ffnn_fcon_bias,
|
||||
|
@ -404,10 +417,13 @@ ircd_gpt_lm_embed(__global const struct ircd_gpt_task *const ctrl,
|
|||
__global const union ircd_gpt_tokenv *const restrict vocab)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0);
|
||||
li = get_local_id(0),
|
||||
wi = get_group_id(0),
|
||||
wn = get_num_groups(0);
|
||||
|
||||
for(uint i = 0; i < ctrl->tokens; ++i)
|
||||
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
|
||||
if(i % wn == wi)
|
||||
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
|
||||
}
|
||||
|
||||
__kernel void
|
||||
|
@ -429,9 +445,9 @@ ircd_gpt_lm_norm(__global const struct ircd_gpt_task *const ctrl,
|
|||
|
||||
// Final re-normalization
|
||||
ircd_simt_math_norm_f4lldr(token.word, token.word, tmp.word);
|
||||
ircd_gpt_norm_fmad(tmp.word, token.word, norm_bias, norm_weight, li);
|
||||
ircd_gpt_norm_fmad(token.word, token.word, norm_bias, norm_weight, li);
|
||||
|
||||
accum[wi].word[li] = tmp.word[li];
|
||||
accum[wi].word[li] = token.word[li];
|
||||
}
|
||||
|
||||
__kernel void
|
||||
|
|
|
@ -277,8 +277,8 @@ ircd::gpt::pipe::exec::exec(task &task,
|
|||
}
|
||||
,range_lm_embed
|
||||
{
|
||||
{ 1 * 192UL, 0, },
|
||||
{ 192UL, 0, },
|
||||
{ std::min(tokens, 12UL) * 192UL, 0, },
|
||||
{ 192UL, 0, },
|
||||
}
|
||||
,range_negative
|
||||
{
|
||||
|
@ -303,8 +303,8 @@ ircd::gpt::pipe::exec::exec(task &task,
|
|||
}
|
||||
,range_lm_logsm
|
||||
{
|
||||
{ 1 * 256UL, 0 },
|
||||
{ 256UL, 0 },
|
||||
{ 1 * 256UL, 0 },
|
||||
{ 256UL, 0 },
|
||||
}
|
||||
,range_lm_select
|
||||
{
|
||||
|
@ -459,12 +459,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
}
|
||||
,state
|
||||
{
|
||||
96 * 3 * 768 * sizeof(float),
|
||||
512 * 3 * 768 * sizeof(float),
|
||||
mutable_buffer{}
|
||||
}
|
||||
,accum
|
||||
{
|
||||
96 * 768 * sizeof(float),
|
||||
512 * 768 * sizeof(float),
|
||||
mutable_buffer{}
|
||||
}
|
||||
,logit
|
||||
|
|
Loading…
Reference in a new issue