Mirror of https://github.com/matrix-construct/construct (synced 2024-12-26 07:23:53 +01:00)
ircd::gpt::pipe: Optimize pipeline to cache attention state for generations.
commit 8a3eeb46f9
parent c5f159ad58
6 changed files with 95 additions and 53 deletions
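In effect, the pipeline now caches each layer's qry/key/val projections in a persistent device buffer across generation steps: the full token range is computed once for the prompt, and each subsequent step only projects the newest token. Below is a minimal host-side sketch of that caching pattern, in plain C++ with hypothetical names (this is not the ircd::gpt API):

```cpp
// Sketch of attention-state caching across generation steps.
// All names here are illustrative, not the ircd::gpt interfaces.
#include <cstddef>
#include <cstdio>
#include <vector>

constexpr size_t layers = 12, max_tokens = 512, embed = 768;

// Persistent per-layer qry/key/val projections; survives between steps,
// so a step only has to fill in the slot for the newest token.
std::vector<float> state(layers * max_tokens * 3 * embed);

// Hypothetical projection of one token in one layer.
void project_token(size_t layer, size_t token)
{
	float *const slot = state.data()
		+ layer * max_tokens * 3 * embed
		+ token * 3 * embed;
	for(size_t i = 0; i < 3 * embed; ++i)
		slot[i] = 0.0f; // stands in for the real qry/key/val computation
}

int main()
{
	size_t count = 4; // prompt tokens: project everything once ("full" pass)
	for(size_t l = 0; l < layers; ++l)
		for(size_t t = 0; t < count; ++t)
			project_token(l, t);

	for(size_t step = 0; step < 3; ++step) // generation: only the new token
	{
		const size_t newest = count++;
		for(size_t l = 0; l < layers; ++l)
			project_token(l, newest);
		std::printf("step %zu projected token %zu only\n", step, newest);
	}
}
```

The real pipeline keeps this state on the device, in the new `[root]` state buffer introduced by the first hunk below, so nothing is recomputed or re-uploaded between steps.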
@@ -20,14 +20,14 @@ struct ircd::gpt::pipe::desc
 	pipe::code *code;
 
 	cl::data
-		master,
-		state,      // qry/key/val projection (tokens * embed * 3 * float)
-		accum,      // accumulator (tokens * embed * float)
-		logit,      // result logit vector (50257 * float)
-		logexp,     // outputs distribution (50257 * float)
-		logsm,      // outputs distribution (50257 * float)
-		ctrl,       // control page
-		opts;       // options page
+		state,      // [root] projection (layers * tokens * embed * 3 * float)
+		master,     // [root] single allocation for additional buffers:
+		accum,      // [-sub] accumulator (tokens * embed * float)
+		logit,      // [-sub] result logit vector (50257 * float)
+		logexp,     // [-sub] outputs distribution (50257 * float)
+		logsm,      // [-sub] outputs distribution (50257 * float)
+		ctrl,       // [root] control page
+		opts;       // [root] options page
 
 	cl::kern
 		lm_embed,

@@ -46,6 +46,9 @@ struct ircd::gpt::pipe::desc
 
 struct ircd::gpt::pipe::desc::layer
 {
+	cl::data
+		state;      // [-sub] qry/key/val projection (tokens * embed * 3 * float)
+
 	cl::kern
 		negative,
 		positive,
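The `[root]`/`[-sub]` annotations distinguish top-level device allocations from sub-buffer views carved out of a parent. As a rough size check, assuming the GPT-2 small geometry used throughout this diff (12 layers, 512 max tokens, 768 embed elems):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
	const size_t layers = 12, tokens = 512, embed = 768, f = sizeof(float);

	// [root] state: one allocation holding qry/key/val for every layer.
	const size_t state = layers * tokens * 3 * embed * f;

	// [root] master: pool for the remaining [-sub] buffers.
	const size_t accum  = tokens * embed * f;
	const size_t logit  = 65536 * f; // 50257 floats, padded allocation
	const size_t master = accum + 3 * logit; // logit + logexp + logsm

	std::printf("state:  %zu MiB\n", state / (1024 * 1024)); // 54 MiB
	std::printf("master: %zu KiB\n", master / 1024);         // 2304 KiB

	// Each pipe::desc::layer views its own slice of state:
	for(size_t l = 0; l < layers; ++l)
		std::printf("layer %2zu sub-buffer offset: %zu\n",
		            l, l * tokens * 3 * embed * f);
}
```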
@@ -37,6 +37,8 @@ struct ircd::gpt::pipe::exec
 		recv_ctrl;         // Set when receiving the control page.
 
 	cl::kern::range
+		range_full,
+		range_last,
 		range_lm_embed,    // Dimension range of the lm_embed kernel.
 		range_negative,    // Dimension range of a layer kernel.
 		range_positive,    // Dimension range of a layer kernel.
@@ -51,6 +51,9 @@ struct ircd_gpt_opts
 	/// Embedding vector elements
 	uint embed_elems;
 
 	/// Cross-attention dimension
 	uint attn_rank;
 
+	/// Attention unit fcon width multiple
+	uint attn_mult;
+

@@ -317,6 +317,10 @@ noexcept
 {
 	768U
 }
 ,attn_rank
 {
 	12U
 }
+,attn_mult
+{
+	3U
+}
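The new `attn_mult` option names the previously implicit factor of three in the attention fully-connected width (concatenated qry/key/val). With the defaults set above, the derived geometry works out as in this sketch (the arithmetic is standard GPT-2; only the printing scaffold is invented):

```cpp
#include <cstdio>

int main()
{
	const unsigned embed_elems = 768; // embedding vector elements
	const unsigned attn_rank   = 12;  // attention heads per layer
	const unsigned attn_mult   = 3;   // fcon width multiple: qry + key + val

	const unsigned head_elems = embed_elems / attn_rank; // 64 per head
	const unsigned fcon_width = embed_elems * attn_mult; // 2304 = q|k|v

	std::printf("head size: %u, attn fcon width: %u\n",
	            head_elems, fcon_width);
}
```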
@@ -159,11 +159,13 @@ inline void
 __attribute__((flatten, always_inline))
 ircd_gpt_attn_self_samax(__global const struct ircd_gpt_ctrl *const ctrl,
                          __constant const struct ircd_gpt_opts *const opts,
-                         __local float self[][12])
+                         __local float self[][12],
+                         const uint wn)
 {
 	const uint
 	gn = get_global_size(0),
 	li = get_local_id(0),
-	wn = get_num_groups(0);
+	ln = get_local_size(0);
 
 	struct ircd_math_samax samax =
 	{

@@ -201,14 +203,16 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
 	gn = get_global_size(0),
 	li = get_local_id(0),
 	ln = get_local_size(0),
-	wi = get_group_id(0),
-	wn = get_num_groups(0),
-	ti = li % 12,
-	ki = li / 12;
+	wi = get_global_offset(0) / ln + get_group_id(0),
+	wn = ctrl->tokens.count,
+	ti = li % opts->attn_rank,
+	ki = li / opts->attn_rank,
+	kn = ln / opts->attn_rank;
 
 	// Low-rank mask
-	if(li < 12)
+	if(li < opts->attn_rank)
 	{
 		// For each token
 		for(uint i = 0; i < wn; ++i)
 		{
 			// Left-attention mask

@@ -219,7 +223,8 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
 			}
 
 			float4 acc = 0.0f;
-			for(uint k = 0; k < 64/4; ++k)
+			__attribute__((opencl_unroll_hint))
+			for(uint k = 0; k < kn; ++k)
 			{
 				float4
 				qry = token[wi].qry.attn[li][k],

@@ -236,7 +241,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
 		}
 
 		// Three-piece softmax
-		ircd_gpt_attn_self_samax(ctrl, opts, self);
+		ircd_gpt_attn_self_samax(ctrl, opts, self, wn);
 	}
 
 	// Propagate to full width for value dot prod.

@@ -244,7 +249,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
 
 	float4 acc = 0.0f;
 	__attribute__((opencl_unroll_hint))
-	for(uint i = 0; i < wn; ++i)
+	for(uint i = 0; i < wi; ++i)
 	{
 		const float4
 		attn = self[i][ti],

@@ -420,7 +425,7 @@ inline void
 __attribute__((always_inline))
 _ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
                    __constant const struct ircd_gpt_opts *const opts,
-                   __global union ircd_gpt_tokenv *const restrict out,
+                   __global union ircd_gpt_tokenv *const restrict accum,
                    __global const union ircd_gpt_tokenv *const restrict pos,
                    __global const union ircd_gpt_tokenv *const restrict vocab,
                    const uint out_idx,

@@ -435,7 +440,7 @@ _ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
 	wte = vocab[token].word[word_idx],
 	wpe = pos[tok_idx].word[word_idx];
 
-	out[out_idx].word[word_idx] = wte + wpe;
+	accum[out_idx].word[word_idx] = wte + wpe;
 }
 
 __kernel void

@@ -451,9 +456,7 @@ ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
 	ln = get_local_size(0),
 	wi = get_global_offset(0) / ln + get_group_id(0);
 
-	for(uint i = 0; i < ctrl->tokens.count; ++i)
-		if(i % wn == wi)
-			_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
+	_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, wi, wi, li);
 }
 
 __kernel void
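A pattern runs through these kernel changes: the work-group's token index `wi` is now derived from `get_global_offset(0)` rather than assuming one group per token of the whole context, and the true token count comes from `ctrl->tokens.count` instead of `get_num_groups(0)`. That is what lets a single enqueued work-group, offset to the last position, compute only the newest token. A plain-C++ model of the index math (work-item builtins emulated as parameters; a sketch, not the kernel itself):

```cpp
#include <cassert>

// Emulates wi = get_global_offset(0) / ln + get_group_id(0)
// for a kernel whose work-groups are ln work-items wide.
unsigned token_index(unsigned global_offset, unsigned ln, unsigned group_id)
{
	return global_offset / ln + group_id;
}

int main()
{
	const unsigned ln = 192; // local size used by these kernels

	// Full pass over 4 tokens: offset 0, groups 0..3 -> tokens 0..3.
	for(unsigned g = 0; g < 4; ++g)
		assert(token_index(0, ln, g) == g);

	// Incremental pass: one group enqueued at offset (tokens - 1) * ln,
	// so group 0 still lands on the newest token (index 3 here).
	assert(token_index(3 * ln, ln, 0) == 3);
}
```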
@@ -224,49 +224,71 @@ ircd::gpt::pipe::exec::exec(task &task,
 ,send_opts
 {
 	reinterpret_cast<const char *>(task.opts),
-	release? sizeof(gpt::opts): 0
+	release?
+		sizeof(gpt::opts):
+		0
 }
 ,send_ctrl
 {
 	reinterpret_cast<const char *>(task.ctrl),
-	release? sizeof(gpt::ctrl): 0
+	release?
+		sizeof(gpt::ctrl):
+		0
 }
 ,send_coil
 {
 	reinterpret_cast<const char *>(gpt::model::default_model),
-	release && desc->model->invalid? (sizeof(gpt::model::block) * 12 + sizeof(gpt::model::norm)): 0
+	release && desc->model->invalid?
+		(sizeof(gpt::model::block) * 12 + sizeof(gpt::model::norm)):
+		0
 }
 ,send_head
 {
 	reinterpret_cast<const char *>(&gpt::model::default_model->word),
-	release && desc->model->invalid? sizeof(gpt::model::embed): 0
+	release && desc->model->invalid?
+		sizeof(gpt::model::embed):
+		0
 }
 ,recv_ctrl
 {
 	reinterpret_cast<char *>(task.ctrl),
-	acquire? sizeof(gpt::ctrl): 0
+	acquire?
+		sizeof(gpt::ctrl):
+		0
 }
-,range_lm_embed
-{
-	{ std::min(tokens, 12UL) * 192UL, 0, },
-	{ 192UL, 0, },
-}
-,range_negative
+,range_full
 {
 	{ tokens * 192UL, 0, },
 	{ 192UL, 0, },
 }
-,range_positive
-{
-	{ tokens * 192UL, 0, },
-	{ 192UL, 0, },
-}
-,range_lm_norm
+,range_last
 {
 	{ 1 * 192UL, 0 },
 	{ 192UL, 0 },
 	{ (tokens - 1) * 192UL, 0 },
 }
+,range_lm_embed
+{
+	release?
+		range_full:
+		range_last
+}
+,range_negative
+{
+	release?
+		range_full:
+		range_last
+}
+,range_positive
+{
+	release?
+		range_full:
+		range_last
+}
+,range_lm_norm
+{
+	range_last
+}
 ,range_lm_logit
 {
 	{ 786 * 64UL, 0 }, // align_up(50257) / 64
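`range_full` spans one 192-wide work-group per token; `range_last` is a single group whose global offset lands on the newest token. The token-parallel kernels then select between them on `release`: a fresh prompt runs the full range, while generation steps reuse the cached state and touch only the last token. A small sketch of the selection, assuming the {global, local, offset} reading of the three initializer rows above:

```cpp
#include <cstddef>
#include <cstdio>

struct range { size_t global; size_t local; size_t offset; };

int main()
{
	const size_t tokens = 4;

	const range range_full { tokens * 192UL, 192UL, 0 };
	const range range_last { 1 * 192UL, 192UL, (tokens - 1) * 192UL };

	for(const bool release : { true, false })
	{
		const range &r = release? range_full: range_last;
		std::printf("release=%d: %zu group(s) starting at token %zu\n",
		            release, r.global / r.local, r.offset / r.local);
	}
}
```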
@@ -453,30 +475,27 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
 {
 	&code
 }
+,state
+{
+	0
+	+ 12 * 512 * 3 * 768 * sizeof(float),
+	mutable_buffer{},
+}
 ,master
 {
 	0
-	+ 512 * 3 * 768 * sizeof(float)
 	+ 512 * 768 * sizeof(float)
 	+ 65536 * sizeof(float)
 	+ 65536 * sizeof(float)
 	+ 65536 * sizeof(float)
 	,mutable_buffer{}
 }
-,state
-{
-	master,
-	{
-		512 * 3 * 768 * sizeof(float),
-		off_t(0),
-	},
-}
 ,accum
 {
 	master,
 	{
 		512 * 768 * sizeof(float),
-		state.offset() + off_t(state.size()),
+		off_t(0),
 	},
 }
 ,logit
@@ -613,13 +632,21 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
 
 ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
                                     const int laynum)
-:negative
+:state
+{
+	desc.state,
+	{
+		512 * 3 * 768 * sizeof(float),
+		laynum * 512 * 3 * 768 * sizeof(float),
+	}
+}
+,negative
 {
 	*desc.code,
 	"ircd_gpt_attn_fcon",
 	desc.ctrl,
 	desc.opts,
-	desc.state,
+	state,
 	desc.accum,
 	desc.model->decode->block[laynum].attn.norm.bias.param,
 	desc.model->decode->block[laynum].attn.norm.weight.param,

@@ -633,7 +660,7 @@ ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
 	desc.ctrl,
 	desc.opts,
 	desc.accum,
-	desc.state,
+	state,
 	desc.model->decode->block[laynum].attn.proj.bias.param,
 	desc.model->decode->block[laynum].attn.proj.weight.param,
 	desc.model->decode->block[laynum].ffnn.norm.bias.param,
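Each layer now binds its own slice of the shared state buffer to both of its kernels, rather than the whole root allocation: the negative (attn fcon) kernel writes the layer's qry/key/val there, and the positive (attn proj) kernel reads it back on the next pass. Assuming `cl::data(parent, {size, offset})` wraps the standard OpenCL sub-buffer mechanism, the per-layer view corresponds to something like this sketch:

```cpp
// Sketch of what a per-layer sub-buffer view amounts to at the OpenCL
// level; clCreateSubBuffer is the standard API, while the mapping from
// ircd's cl::data wrapper onto it is an assumption here.
#include <CL/cl.h>

cl_mem layer_state(cl_mem state_root, int laynum)
{
	const size_t slice = 512 * 3 * 768 * sizeof(float);
	const cl_buffer_region region
	{
		laynum * slice, // origin: this layer's offset into the root buffer
		slice,          // size: qry/key/val for all tokens of one layer
	};
	cl_int err;
	return clCreateSubBuffer(state_root, CL_MEM_READ_WRITE,
	                         CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
}
```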