mirror of https://github.com/matrix-construct/construct

ircd::gpt::pipe: Optimize pipeline to cache attention state for generations.

Jason Volk 2021-09-16 23:03:44 -07:00
parent c5f159ad58
commit 8a3eeb46f9
6 changed files with 95 additions and 53 deletions
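
What follows caches each layer's attention state between generation cycles: the qry/key/val projections for tokens already seen stay resident on the device, and a cached step computes only the newly appended token's row, reading earlier rows back from the cache. A minimal host-side sketch of the idea in plain C++, with hypothetical names and the GPT-2-small sizes this diff assumes (12 layers, 512-token window, 768-wide embedding); this illustrates the technique, not the project's actual API:

#include <cstddef>
#include <vector>

struct attn_cache
{
    static constexpr size_t layers {12}, max_tokens {512}, embed {768};

    // One qry/key/val triple per token per layer; mirrors the [root]
    // state buffer added by this commit (layers * tokens * embed * 3 floats).
    std::vector<float> state = std::vector<float>(layers * max_tokens * embed * 3);

    // Offset of one token's triple within one layer's slice.
    static constexpr size_t offset(const size_t layer, const size_t token)
    {
        return (layer * max_tokens + token) * embed * 3;
    }
};

// A fresh context projects every token; every later generation step
// projects only the token appended by the previous cycle.
void step(attn_cache &cache, const size_t tokens, const bool fresh)
{
    for(size_t layer(0); layer < attn_cache::layers; ++layer)
        for(size_t token(fresh? 0: tokens - 1); token < tokens; ++token)
        {
            float *const qkv
            {
                cache.state.data() + attn_cache::offset(layer, token)
            };
            qkv[0] = 0.0f; // stands in for the real projection of this token
        }
}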

View file

@@ -20,14 +20,14 @@ struct ircd::gpt::pipe::desc
pipe::code *code;
cl::data
master,
state, // qry/key/val projection (tokens * embed * 3 * float)
accum, // accumulator (tokens * embed * float)
logit, // result logit vector (50257 * float)
logexp, // outputs distribution (50257 * float)
logsm, // outputs distribution (50257 * float)
ctrl, // control page
opts; // options page
state, // [root] projection (layers * tokens * embed * 3 * float)
master, // [root] single allocation for additional buffers:
accum, // [-sub] accumulator (tokens * embed * float)
logit, // [-sub] result logit vector (50257 * float)
logexp, // [-sub] outputs distribution (50257 * float)
logsm, // [-sub] outputs distribution (50257 * float)
ctrl, // [root] control page
opts; // [root] options page
cl::kern
lm_embed,
@@ -46,6 +46,9 @@ struct ircd::gpt::pipe::desc
struct ircd::gpt::pipe::desc::layer
{
cl::data
state; // [-sub] qry/key/val projection (tokens * embed * 3 * float)
cl::kern
negative,
positive,

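The [root]/[-sub] annotations above distinguish whole device allocations from sub-buffers carved out of one root allocation at fixed offsets: accum, logit, logexp and logsm become successive regions of master, while state, ctrl and opts remain root buffers. A hedged sketch of such a carve using the stock OpenCL C API (ircd's cl::data wrapper is not shown; names here are illustrative):

#define CL_TARGET_OPENCL_VERSION 120
#include <CL/cl.h>
#include <cstddef>

// Carve one [-sub] region out of a [root] allocation. The origin must
// satisfy the device's CL_DEVICE_MEM_BASE_ADDR_ALIGN requirement.
static cl_mem
make_sub(cl_mem master, const size_t origin, const size_t size)
{
    const cl_buffer_region region {origin, size};
    cl_int err {CL_SUCCESS};
    cl_mem ret
    {
        clCreateSubBuffer
        (
            master,
            CL_MEM_READ_WRITE,
            CL_BUFFER_CREATE_TYPE_REGION,
            &region,
            &err
        )
    };
    return err == CL_SUCCESS? ret: nullptr;
}

Packing related buffers into one root allocation keeps them contiguous and cuts the number of device allocations the runtime must track.
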
View file

@@ -37,6 +37,8 @@ struct ircd::gpt::pipe::exec
recv_ctrl; // Set when receiving the control page.
cl::kern::range
range_full,
range_last,
range_lm_embed, // Dimension range of the lm_embed kernel.
range_negative, // Dimension range of a layer kernel.
range_positive, // Dimension range of a layer kernel.

View file

@@ -51,6 +51,9 @@ struct ircd_gpt_opts
/// Embedding vector elements
uint embed_elems;
/// Cross-attention dimension
uint attn_rank;
/// Attention unit fcon width multiple
uint attn_mult;

View file

@@ -317,6 +317,10 @@ noexcept
{
768U
}
,attn_rank
{
12U
}
,attn_mult
{
3U

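The new defaults correspond to GPT-2 small: 12 attention heads over the 768-element embedding give 64-element heads, and the 3x multiple is the width of the fused qry/key/val projection. A compile-time check of that arithmetic (illustrative only):

constexpr unsigned embed_elems {768}, attn_rank {12}, attn_mult {3};
constexpr unsigned head_elems {embed_elems / attn_rank}; // 64 elements per head
constexpr unsigned fcon_elems {embed_elems * attn_mult}; // 2304-wide fused projection
static_assert(head_elems == 64 && fcon_elems == 2304);
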
View file

@@ -159,11 +159,13 @@ inline void
__attribute__((flatten, always_inline))
ircd_gpt_attn_self_samax(__global const struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local float self[][12])
__local float self[][12],
const uint wn)
{
const uint
gn = get_global_size(0),
li = get_local_id(0),
wn = get_num_groups(0);
ln = get_local_size(0);
struct ircd_math_samax samax =
{
@@ -201,14 +203,16 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
gn = get_global_size(0),
li = get_local_id(0),
ln = get_local_size(0),
wi = get_group_id(0),
wn = get_num_groups(0),
ti = li % 12,
ki = li / 12;
wi = get_global_offset(0) / ln + get_group_id(0),
wn = ctrl->tokens.count,
ti = li % opts->attn_rank,
ki = li / opts->attn_rank,
kn = ln / opts->attn_rank;
// Low-rank mask
if(li < 12)
if(li < opts->attn_rank)
{
// For each token
for(uint i = 0; i < wn; ++i)
{
// Left-attention mask
@@ -219,7 +223,8 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
}
float4 acc = 0.0f;
for(uint k = 0; k < 64/4; ++k)
__attribute__((opencl_unroll_hint))
for(uint k = 0; k < kn; ++k)
{
float4
qry = token[wi].qry.attn[li][k],
@@ -236,7 +241,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
}
// Three-piece softmax
ircd_gpt_attn_self_samax(ctrl, opts, self);
ircd_gpt_attn_self_samax(ctrl, opts, self, wn);
}
// Propagate to full width for value dot prod.
@@ -244,7 +249,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
float4 acc = 0.0f;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
for(uint i = 0; i < wi; ++i)
{
const float4
attn = self[i][ti],
@@ -420,7 +425,7 @@ inline void
__attribute__((always_inline))
_ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__global union ircd_gpt_tokenv *const restrict out,
__global union ircd_gpt_tokenv *const restrict accum,
__global const union ircd_gpt_tokenv *const restrict pos,
__global const union ircd_gpt_tokenv *const restrict vocab,
const uint out_idx,
@@ -435,7 +440,7 @@ _ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
wte = vocab[token].word[word_idx],
wpe = pos[tok_idx].word[word_idx];
out[out_idx].word[word_idx] = wte + wpe;
accum[out_idx].word[word_idx] = wte + wpe;
}
__kernel void
@@ -451,9 +456,7 @@ ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
ln = get_local_size(0),
wi = get_global_offset(0) / ln + get_group_id(0);
for(uint i = 0; i < ctrl->tokens.count; ++i)
if(i % wn == wi)
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, wi, wi, li);
}
__kernel void

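Two indexing changes in the kernel above follow from the caching: the work-group index wi is now recovered from the global offset, because a cached step launches a single group at offset (tokens - 1) * 192 rather than one group per token, and the token count wn is read from ctrl rather than get_num_groups(0) for the same reason. A C++ rendering of that index math, under those assumptions:

#include <cassert>
#include <cstddef>

// Mirrors wi = get_global_offset(0) / get_local_size(0) + get_group_id(0)
static size_t
absolute_group(const size_t global_offset, const size_t local_size, const size_t group_id)
{
    return global_offset / local_size + group_id;
}

int main()
{
    const size_t tokens {7}, local_size {192};

    // Fresh context: one group per token, zero offset.
    for(size_t g(0); g < tokens; ++g)
        assert(absolute_group(0, local_size, g) == g);

    // Cached step: the single group launched at (tokens - 1) * 192
    // resolves to the last token's row.
    assert(absolute_group((tokens - 1) * 192, local_size, 0) == tokens - 1);
    return 0;
}
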
View file

@@ -224,49 +224,71 @@ ircd::gpt::pipe::exec::exec(task &task,
,send_opts
{
reinterpret_cast<const char *>(task.opts),
release? sizeof(gpt::opts): 0
release?
sizeof(gpt::opts):
0
}
,send_ctrl
{
reinterpret_cast<const char *>(task.ctrl),
release? sizeof(gpt::ctrl): 0
release?
sizeof(gpt::ctrl):
0
}
,send_coil
{
reinterpret_cast<const char *>(gpt::model::default_model),
release && desc->model->invalid? (sizeof(gpt::model::block) * 12 + sizeof(gpt::model::norm)): 0
release && desc->model->invalid?
(sizeof(gpt::model::block) * 12 + sizeof(gpt::model::norm)):
0
}
,send_head
{
reinterpret_cast<const char *>(&gpt::model::default_model->word),
release && desc->model->invalid? sizeof(gpt::model::embed): 0
release && desc->model->invalid?
sizeof(gpt::model::embed):
0
}
,recv_ctrl
{
reinterpret_cast<char *>(task.ctrl),
acquire? sizeof(gpt::ctrl): 0
acquire?
sizeof(gpt::ctrl):
0
}
,range_lm_embed
{
{ std::min(tokens, 12UL) * 192UL, 0, },
{ 192UL, 0, },
}
,range_negative
,range_full
{
{ tokens * 192UL, 0, },
{ 192UL, 0, },
}
,range_positive
{
{ tokens * 192UL, 0, },
{ 192UL, 0, },
}
,range_lm_norm
,range_last
{
{ 1 * 192UL, 0 },
{ 192UL, 0 },
{ (tokens - 1) * 192UL, 0 },
}
,range_lm_embed
{
release?
range_full:
range_last
}
,range_negative
{
release?
range_full:
range_last
}
,range_positive
{
release?
range_full:
range_last
}
,range_lm_norm
{
range_last
}
,range_lm_logit
{
{ 786 * 64UL, 0 }, // align_up(50257) / 64
@@ -453,30 +475,27 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
{
&code
}
,state
{
0
+ 12 * 512 * 3 * 768 * sizeof(float),
mutable_buffer{},
}
,master
{
0
+ 512 * 3 * 768 * sizeof(float)
+ 512 * 768 * sizeof(float)
+ 65536 * sizeof(float)
+ 65536 * sizeof(float)
+ 65536 * sizeof(float)
,mutable_buffer{}
}
,state
{
master,
{
512 * 3 * 768 * sizeof(float),
off_t(0),
},
}
,accum
{
master,
{
512 * 768 * sizeof(float),
state.offset() + off_t(state.size()),
off_t(0),
},
}
,logit
@@ -613,13 +632,21 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
const int laynum)
:negative
:state
{
desc.state,
{
512 * 3 * 768 * sizeof(float),
laynum * 512 * 3 * 768 * sizeof(float),
}
}
,negative
{
*desc.code,
"ircd_gpt_attn_fcon",
desc.ctrl,
desc.opts,
desc.state,
state,
desc.accum,
desc.model->decode->block[laynum].attn.norm.bias.param,
desc.model->decode->block[laynum].attn.norm.weight.param,
@@ -633,7 +660,7 @@ ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
desc.ctrl,
desc.opts,
desc.accum,
desc.state,
state,
desc.model->decode->block[laynum].attn.proj.bias.param,
desc.model->decode->block[laynum].attn.proj.weight.param,
desc.model->decode->block[laynum].ffnn.norm.bias.param,
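
desc::layer now also owns a per-layer window into the [root] state buffer: layer n's qry/key/val cache is the fixed-size region at n * 512 * 3 * 768 floats, and the layer's kernels are bound to that window instead of the old shared scratch buffer. A compile-time sketch of the carve, with the sizes taken from this diff:

#include <cstddef>

constexpr size_t
layer_state_bytes {512 * 3 * 768 * sizeof(float)};

// Offset of layer `laynum`'s window inside the [root] state allocation.
constexpr size_t
layer_state_offset(const size_t laynum)
{
    return laynum * layer_state_bytes;
}

// Twelve layer windows tile the root allocation exactly.
constexpr size_t
state_root_bytes {12 * 512 * 3 * 768 * sizeof(float)};

static_assert(layer_state_offset(12) == state_root_bytes);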