
ircd::gpt::pipe: Expansions to the context size.

Author: Jason Volk, 2021-04-22 12:17:29 -07:00
parent f61239a52c
commit 5392ba0c7d
2 changed files with 40 additions and 24 deletions

Changed file 1 (OpenCL kernel source):

@@ -10,8 +10,8 @@
 inline void
-ircd_gpt_norm_fmad(__local float4 *const restrict out,
-                   __local const float4 *const restrict in,
+ircd_gpt_norm_fmad(__local float4 *const out,
+                   __local const float4 *const in,
                    __global const float4 *const restrict bias,
                    __global const float4 *const restrict weight,
                    const uint i)
@@ -36,8 +36,19 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
     acc = bias[i];
 
     for(uint j = 0; j < height; ++j)
+    {
+        const uint
+        tile = j * lanes;
+
         for(uint k = 0; k < lanes; ++k)
-            acc += in[j][k] * weight[width * (j * lanes + k) + i];
+        {
+            const uint
+            row = tile + k,
+            cell = row * width + i;
+
+            acc += in[j][k] * weight[cell];
+        }
+    }
 
     out[i] = acc;
 }
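The inner loop here is a refactor of the weight indexing rather than a functional change: with tile = j * lanes and row = tile + k, the new cell = row * width + i is exactly the old width * (j * lanes + k) + i, just split into named steps per tile row. A minimal plain-C sketch (hypothetical sizes, host code only) checking that the two addressings agree:

    #include <assert.h>

    /* Hypothetical tile dimensions, for illustration only. */
    enum { HEIGHT = 4, LANES = 4, WIDTH = 192 };

    int main(void)
    {
        for(unsigned i = 0; i < WIDTH; ++i)
            for(unsigned j = 0; j < HEIGHT; ++j)
                for(unsigned k = 0; k < LANES; ++k)
                {
                    const unsigned old_cell = WIDTH * (j * LANES + k) + i;

                    const unsigned tile = j * LANES,
                                   row = tile + k,
                                   new_cell = row * WIDTH + i;

                    assert(old_cell == new_cell); /* identical addressing */
                }
        return 0;
    }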
@@ -95,7 +106,6 @@ __attribute__((always_inline))
 ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
               __constant const struct ircd_gpt_opts *const opts,
               __local union ircd_gpt_tokenv *const restrict token,
-              __local union ircd_gpt_tokenv *const restrict tmp,
               __local union ircd_gpt_ffnn_aperaturev *const restrict buf,
               __global const float4 *const restrict norm_bias,
               __global const float4 *const restrict norm_weight,
@@ -116,13 +126,13 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
     // Layer re-normalization
     ircd_simt_math_norm_f4lldr(token->word, token->word, buf->word);
-    ircd_gpt_norm_fmad(tmp->word, token->word, norm_bias, norm_weight, li);
+    ircd_gpt_norm_fmad(token->word, token->word, norm_bias, norm_weight, li);
 
     // ln's writes are still pending but fcon reads results across threads.
     barrier(CLK_LOCAL_MEM_FENCE);
 
     // Fully connected
-    ircd_gpt_ffnn_fcon(ctrl, opts, buf, tmp, fcon_bias, fcon_weight);
+    ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight);
 
     // fcon's writes are still pending but proj reads results across threads.
     barrier(CLK_LOCAL_MEM_FENCE);
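Dropping the tmp token buffer means ircd_gpt_norm_fmad() now writes its result back into the same token vector it reads, which is also why the restrict qualifiers came off its out/in parameters in the first hunk: the two pointers may legally alias now. A rough scalar C sketch of the idea (not the real kernel, names simplified):

    /* In-place scale-and-shift: out and in may point at the same buffer,
     * so neither is declared restrict. Each lane i only touches its own
     * element, which keeps the in-place update race-free per work-item. */
    static void norm_fmad(float *out, const float *in,
                          const float *bias, const float *weight,
                          unsigned i)
    {
        out[i] = in[i] * weight[i] + bias[i];
    }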
@@ -255,15 +265,19 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
     li = get_local_id(0),
     wi = get_group_id(0);
 
-    __local float
-    self[96][12];
-
     __local union ircd_gpt_tokenv
     buf1, buf0;
 
-    __local union ircd_gpt_ffnn_aperaturev
-    ffnn_fcon;
+    __local union
+    {
+        union ircd_gpt_ffnn_aperaturev
+        ffnn_fcon;
+
+        float
+        attn_self[512][12];
+    }
+    buf;
 
     // Self-attention backend; this computes the self-attention result now
     // that keys and values are globally visible across tokens.
     ircd_gpt_attn_self
@@ -271,7 +285,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
     (
         ctrl,
         opts,
         &buf1,
-        self,
+        buf.attn_self,
         state,
         mask
     );
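The two big per-work-group scratch areas belong to phases that do not overlap in time within a layer: the self-attention score matrix is consumed before the FFNN fan-out buffer is needed. Folding them into one anonymous __local union means the allocation only has to be as large as the bigger member, which is what lets the score matrix grow from 96 to 512 rows without adding another live local buffer. A host-side C sketch of the overlap (the aperture size here is an assumed stand-in for the real ircd_gpt_ffnn_aperaturev, sized as a GPT-2-like 3072-float hidden layer):

    #include <stdio.h>

    /* Assumed stand-in for ircd_gpt_ffnn_aperaturev; the real type differs. */
    union ffnn_aperaturev
    {
        float fcon[4 * 768];
    };

    union scratch
    {
        union ffnn_aperaturev ffnn_fcon; /* used during the FFNN phase      */
        float attn_self[512][12];        /* used during self-attention      */
    };

    int main(void)
    {
        /* The union occupies only the larger member, not the sum of both. */
        printf("aperature: %zu bytes\n", sizeof(union ffnn_aperaturev));
        printf("attn_self: %zu bytes\n", sizeof(float[512][12]));
        printf("union:     %zu bytes\n", sizeof(union scratch));
        return 0;
    }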
@@ -305,8 +319,7 @@ ircd_gpt_coil(__global const struct ircd_gpt_task *const ctrl,
         ctrl,
         opts,
         &buf0,
-        &buf1,
-        &ffnn_fcon,
+        &buf.ffnn_fcon,
         ffnn_norm_bias,
         ffnn_norm_weight,
         ffnn_fcon_bias,
@@ -404,9 +417,12 @@ ircd_gpt_lm_embed(__global const struct ircd_gpt_task *const ctrl,
     __global const union ircd_gpt_tokenv *const restrict vocab)
 {
     const uint
-    li = get_local_id(0);
+    li = get_local_id(0),
+    wi = get_group_id(0),
+    wn = get_num_groups(0);
 
     for(uint i = 0; i < ctrl->tokens; ++i)
-        _ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
+        if(i % wn == wi)
+            _ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, i, i, li);
 }
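The embedding kernel previously had a single work-group walk every token; it now lets the host launch several groups and hands token i to group i % wn, so the embedding work spreads round-robin over however many groups are dispatched. A small plain-C sketch of that partitioning (an ordinary loop standing in for the NDRange):

    #include <stdio.h>

    /* Which work-group owns token i under the modulo partitioning. */
    static unsigned owner(unsigned i, unsigned wn)
    {
        return i % wn;
    }

    int main(void)
    {
        const unsigned tokens = 7, wn = 4; /* e.g. 7 tokens over 4 groups */

        for(unsigned wi = 0; wi < wn; ++wi)
        {
            printf("group %u embeds tokens:", wi);
            for(unsigned i = 0; i < tokens; ++i)
                if(owner(i, wn) == wi)
                    printf(" %u", i);
            printf("\n");
        }
        return 0;
    }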
@@ -429,9 +445,9 @@ ircd_gpt_lm_norm(__global const struct ircd_gpt_task *const ctrl,
     // Final re-normalization
     ircd_simt_math_norm_f4lldr(token.word, token.word, tmp.word);
-    ircd_gpt_norm_fmad(tmp.word, token.word, norm_bias, norm_weight, li);
+    ircd_gpt_norm_fmad(token.word, token.word, norm_bias, norm_weight, li);
 
-    accum[wi].word[li] = tmp.word[li];
+    accum[wi].word[li] = token.word[li];
 }
 
 __kernel void

Changed file 2 (C++ pipe setup):

@@ -277,7 +277,7 @@ ircd::gpt::pipe::exec::exec(task &task,
     }
     ,range_lm_embed
     {
-        { 1 * 192UL, 0, },
+        { std::min(tokens, 12UL) * 192UL, 0, },
         { 192UL, 0, },
     }
     ,range_negative
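This is the host-side half of the embedding change: the global range grows from one fixed work-group to one group per token, capped at 12, each of 192 work-items (presumably one per float4 lane of the 768-wide embedding). Together with the kernel's i % wn == wi test above, every token still gets embedded no matter how many groups the cap allows. A quick C sketch of the range arithmetic (plain C, so a local min helper stands in for std::min):

    #include <stdio.h>

    static unsigned long min_ul(unsigned long a, unsigned long b)
    {
        return a < b ? a : b;
    }

    int main(void)
    {
        const unsigned long local = 192UL; /* work-items per group */

        for(unsigned long tokens = 1; tokens <= 16; tokens *= 2)
        {
            const unsigned long global = min_ul(tokens, 12UL) * 192UL;
            printf("tokens=%2lu -> global=%4lu (%lu groups of %lu)\n",
                   tokens, global, global / local, local);
        }
        return 0;
    }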
@@ -459,12 +459,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
     }
     ,state
     {
-        96 * 3 * 768 * sizeof(float),
+        512 * 3 * 768 * sizeof(float),
         mutable_buffer{}
     }
     ,accum
     {
-        96 * 768 * sizeof(float),
+        512 * 768 * sizeof(float),
         mutable_buffer{}
     }
     ,logit
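The device buffer sizes follow directly from the larger context window: state apparently holds three 768-float vectors per token and accum one 768-float vector per token, both now sized for 512 tokens instead of 96. A quick C check of the arithmetic:

    #include <stdio.h>

    int main(void)
    {
        const unsigned long ctx = 512, embed = 768, f = sizeof(float);

        printf("state: %lu bytes\n", ctx * 3 * embed * f); /* 4718592 */
        printf("accum: %lu bytes\n", ctx * embed * f);     /* 1572864 */
        return 0;
    }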