mirror of
https://github.com/matrix-construct/construct
synced 2024-12-28 08:24:08 +01:00
ircd::gpt::pipe: Expansions to the context size.
This commit is contained in:
parent
f61239a52c
commit
5392ba0c7d
2 changed files with 40 additions and 24 deletions
|
@ -10,8 +10,8 @@
|
||||||
|
|
||||||
|
|
||||||
inline void
|
inline void
|
||||||
ircd_gpt_norm_fmad(__local float4 *const restrict out,
|
ircd_gpt_norm_fmad(__local float4 *const out,
|
||||||
__local const float4 *const restrict in,
|
__local const float4 *const in,
|
||||||
__global const float4 *const restrict bias,
|
__global const float4 *const restrict bias,
|
||||||
__global const float4 *const restrict weight,
|
__global const float4 *const restrict weight,
|
||||||
const uint i)
|
const uint i)
|
||||||
|
@ -36,8 +36,19 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
|
||||||
acc = bias[i];
|
acc = bias[i];
|
||||||
|
|
||||||
for(uint j = 0; j < height; ++j)
|
for(uint j = 0; j < height; ++j)
|
||||||
|
{
|
||||||
|
const uint
|
||||||
|
tile = j * lanes;
|
||||||
|
|
||||||
for(uint k = 0; k < lanes; ++k)
|
for(uint k = 0; k < lanes; ++k)
|
||||||
acc += in[j][k] * weight[width * (j * lanes + k) + i];
|
{
|
||||||
|
const uint
|
||||||
|
row = tile + k,
|
||||||
|
cell = row * width + i;
|
||||||
|
|
||||||
|
acc += in[j][k] * weight[cell];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
out[i] = acc;
|
out[i] = acc;
|
||||||
}
|
}
|
||||||
|
@ -95,7 +106,6 @@ __attribute__((always_inline))
|
||||||
ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
||||||
__constant const struct ircd_gpt_opts *const opts,
|
__constant const struct ircd_gpt_opts *const opts,
|
||||||
__local union ircd_gpt_tokenv *const restrict token,
|
__local union ircd_gpt_tokenv *const restrict token,
|
||||||
__local union ircd_gpt_tokenv *const restrict tmp,
|
|
||||||
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
|
__local union ircd_gpt_ffnn_aperaturev *const restrict buf,
|
||||||
__global const float4 *const restrict norm_bias,
|
__global const float4 *const restrict norm_bias,
|
||||||
__global const float4 *const restrict norm_weight,
|
__global const float4 *const restrict norm_weight,
|
||||||
|
@ -116,13 +126,13 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
||||||
|
|
||||||
// Layer re-normalization
|
// Layer re-normalization
|
||||||
ircd_simt_math_norm_f4lldr(token->word, token->word, buf->word);
|
ircd_simt_math_norm_f4lldr(token->word, token->word, buf->word);
|
||||||
ircd_gpt_norm_fmad(tmp->word, token->word, norm_bias, norm_weight, li);
|
ircd_gpt_norm_fmad(token->word, token->word, norm_bias, norm_weight, li);
|
||||||
|
|
||||||
// ln's writes are still pending but fcon reads results across threads.
|
// ln's writes are still pending but fcon reads results across threads.
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
|
|
||||||
// Fully connected
|
// Fully connected
|
||||||
ircd_gpt_ffnn_fcon(ctrl, opts, buf, tmp, fcon_bias, fcon_weight);
|
ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight);
|
||||||
|
|
||||||
// fcon's writes are still pending but proj reads results across threads.
|
// fcon's writes are still pending but proj reads results across threads.
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
|
@ -255,14 +265,18 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
||||||
li = get_local_id(0),
|
li = get_local_id(0),
|
||||||
wi = get_group_id(0);
|
wi = get_group_id(0);
|
||||||
|
|
||||||
__local float
|
|
||||||
self[96][12];
|
|
||||||
|
|
||||||
__local union ircd_gpt_tokenv
|
__local union ircd_gpt_tokenv
|
||||||
buf1, buf0;
|
buf1, buf0;
|
||||||
|
|
||||||
__local union ircd_gpt_ffnn_aperaturev
|
__local union
|
||||||
ffnn_fcon;
|
{
|
||||||
|
union ircd_gpt_ffnn_aperaturev
|
||||||
|
ffnn_fcon;
|
||||||
|
|
||||||
|
float
|
||||||
|
attn_self[512][12];
|
||||||
|
}
|
||||||
|
buf;
|
||||||
|
|
||||||
// Self-attention backend; this computes the self-attention result now
|
// Self-attention backend; this computes the self-attention result now
|
||||||
// that keys and values are globally visible across tokens.
|
// that keys and values are globally visible across tokens.
|
||||||
|
@ -271,7 +285,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
||||||
ctrl,
|
ctrl,
|
||||||
opts,
|
opts,
|
||||||
&buf1,
|
&buf1,
|
||||||
self,
|
buf.attn_self,
|
||||||
state,
|
state,
|
||||||
mask
|
mask
|
||||||
);
|
);
|
||||||
|
@ -305,8 +319,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
|
||||||
ctrl,
|
ctrl,
|
||||||
opts,
|
opts,
|
||||||
&buf0,
|
&buf0,
|
||||||
&buf1,
|
&buf.ffnn_fcon,
|
||||||
&ffnn_fcon,
|
|
||||||
ffnn_norm_bias,
|
ffnn_norm_bias,
|
||||||
ffnn_norm_weight,
|
ffnn_norm_weight,
|
||||||
ffnn_fcon_bias,
|
ffnn_fcon_bias,
|
||||||
|
@ -404,10 +417,13 @@ ircd_gpt_lm_embed(__global const struct ircd_gpt_task *const ctrl,
|
||||||
__global const union ircd_gpt_tokenv *const restrict vocab)
|
__global const union ircd_gpt_tokenv *const restrict vocab)
|
||||||
{
|
{
|
||||||
const uint
|
const uint
|
||||||
li = get_local_id(0);
|
li = get_local_id(0),
|
||||||
|
wi = get_group_id(0),
|
||||||
|
wn = get_num_groups(0);
|
||||||
|
|
||||||
for(uint i = 0; i < ctrl->tokens; ++i)
|
for(uint i = 0; i < ctrl->tokens; ++i)
|
||||||
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
|
if(i % wn == wi)
|
||||||
|
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
|
||||||
}
|
}
|
||||||
|
|
||||||
__kernel void
|
__kernel void
|
||||||
|
@ -429,9 +445,9 @@ ircd_gpt_lm_norm(__global const struct ircd_gpt_task *const ctrl,
|
||||||
|
|
||||||
// Final re-normalization
|
// Final re-normalization
|
||||||
ircd_simt_math_norm_f4lldr(token.word, token.word, tmp.word);
|
ircd_simt_math_norm_f4lldr(token.word, token.word, tmp.word);
|
||||||
ircd_gpt_norm_fmad(tmp.word, token.word, norm_bias, norm_weight, li);
|
ircd_gpt_norm_fmad(token.word, token.word, norm_bias, norm_weight, li);
|
||||||
|
|
||||||
accum[wi].word[li] = tmp.word[li];
|
accum[wi].word[li] = token.word[li];
|
||||||
}
|
}
|
||||||
|
|
||||||
__kernel void
|
__kernel void
|
||||||
|
|
|
@ -277,8 +277,8 @@ ircd::gpt::pipe::exec::exec(task &task,
|
||||||
}
|
}
|
||||||
,range_lm_embed
|
,range_lm_embed
|
||||||
{
|
{
|
||||||
{ 1 * 192UL, 0, },
|
{ std::min(tokens, 12UL) * 192UL, 0, },
|
||||||
{ 192UL, 0, },
|
{ 192UL, 0, },
|
||||||
}
|
}
|
||||||
,range_negative
|
,range_negative
|
||||||
{
|
{
|
||||||
|
@ -303,8 +303,8 @@ ircd::gpt::pipe::exec::exec(task &task,
|
||||||
}
|
}
|
||||||
,range_lm_logsm
|
,range_lm_logsm
|
||||||
{
|
{
|
||||||
{ 1 * 256UL, 0 },
|
{ 1 * 256UL, 0 },
|
||||||
{ 256UL, 0 },
|
{ 256UL, 0 },
|
||||||
}
|
}
|
||||||
,range_lm_select
|
,range_lm_select
|
||||||
{
|
{
|
||||||
|
@ -459,12 +459,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
||||||
}
|
}
|
||||||
,state
|
,state
|
||||||
{
|
{
|
||||||
96 * 3 * 768 * sizeof(float),
|
512 * 3 * 768 * sizeof(float),
|
||||||
mutable_buffer{}
|
mutable_buffer{}
|
||||||
}
|
}
|
||||||
,accum
|
,accum
|
||||||
{
|
{
|
||||||
96 * 768 * sizeof(float),
|
512 * 768 * sizeof(float),
|
||||||
mutable_buffer{}
|
mutable_buffer{}
|
||||||
}
|
}
|
||||||
,logit
|
,logit
|
||||||
|
|
Loading…
Reference in a new issue