0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-11-29 02:02:38 +01:00

ircd::gpt: Add backpropagation pipe.

This commit is contained in:
Jason Volk 2021-04-17 11:59:30 -07:00
parent 14a1561cad
commit d5eb1e3a87
8 changed files with 470 additions and 43 deletions

View file

@ -21,7 +21,7 @@ namespace ircd::gpt::model
struct decoder;
constexpr auto align {64};
extern const decoder *default_model;
extern decoder *default_model;
extern string_view default_dataset;
extern std::vector<json::object> default_data;
}

View file

@ -98,11 +98,29 @@ struct ircd_gpt_opts
#endif
;
/// Attention unit fcon width multiple
uint attn_mult
#ifdef __cplusplus
{
3U
}
#endif
;
/// MLP unit fcon width multiple
uint ffnn_mult
#ifdef __cplusplus
{
4U
}
#endif
;
/// Attention unit width multiple
uint attn_elems
#ifdef __cplusplus
{
embed_elems * 3
embed_elems * attn_mult
}
#endif
;
@ -111,7 +129,16 @@ struct ircd_gpt_opts
uint ffnn_elems
#ifdef __cplusplus
{
embed_elems * 4
embed_elems * ffnn_mult
}
#endif
;
/// SIMD lane count
uint lanes
#ifdef __cplusplus
{
4U
}
#endif
;
@ -119,7 +146,7 @@ struct ircd_gpt_opts
uint embed_width
#ifdef __cplusplus
{
embed_elems / 4
embed_elems / lanes
}
#endif
;
@ -127,7 +154,7 @@ struct ircd_gpt_opts
uint attn_width
#ifdef __cplusplus
{
attn_elems / 4
attn_elems / lanes
}
#endif
;
@ -135,7 +162,7 @@ struct ircd_gpt_opts
uint attn_height
#ifdef __cplusplus
{
embed_elems / 4
embed_elems / lanes
}
#endif
;
@ -143,7 +170,7 @@ struct ircd_gpt_opts
uint ffnn_width
#ifdef __cplusplus
{
ffnn_elems / 4
ffnn_elems / lanes
}
#endif
;
@ -151,7 +178,7 @@ struct ircd_gpt_opts
uint ffnn_height
#ifdef __cplusplus
{
embed_elems / 4
embed_elems / lanes
}
#endif
;
@ -199,6 +226,31 @@ struct ircd_gpt_opts
}
#endif
;
float alpha
#ifdef __cplusplus
{
0.001
}
#endif
;
float beta[2]
#ifdef __cplusplus
{
0.9, // Beta1
0.999, // Beta2
}
#endif
;
float epsilon
#ifdef __cplusplus
{
0.000001
}
#endif
;
}
__attribute__((aligned(4096)));

View file

@ -33,7 +33,9 @@ struct ircd::gpt::pipe::desc
lm_norm,
lm_logit,
lm_logsm,
lm_select;
lm_select,
lm_norm_backprop,
lm_embed_backprop;
std::unique_ptr<struct desc::layer>
layer[12];
@ -43,8 +45,11 @@ struct ircd::gpt::pipe::desc
struct ircd::gpt::pipe::desc::layer
{
cl::kern negative;
cl::kern positive;
cl::kern
negative,
positive,
backattn,
backffnn;
layer(pipe::desc &, const int);
};

View file

@ -47,6 +47,10 @@ struct ircd_gpt_task
/// Several cycles may occur during each epoch.
ulong epoch;
/// Accumulates the training epoch count for the task. The counter is
/// incremented by one in device software for each backward propagation.
ulong step;
/// Accumulates the number of tokens produced by the task. Several tokens
/// may be produced each epoch, but currently only one token is produced
/// each cycle.
@ -107,6 +111,18 @@ struct ircd_gpt_task
/// Certainty mean over context
float cert_mean;
/// Final loss
float l2_loss;
/// Final loss sum
float l2_loss_sum;
/// Final loss mean
float l2_loss_mean;
/// Perform backprop
bool prop;
/// The token buffer starts at offset 2048 and continues to the end of
/// the page; options specify the size of the tokens buffer in tokens.
/// Additional pages must be attached for larger buffer sizes.

View file

@ -110,7 +110,7 @@ ircd::gpt::generate(const vector_view<u16> &out,
else
embed(data(dst), in[j], j, opts);
#if RB_DEBUG
#if 0 // RB_DEBUG
static char dbuf[512] {0};
char report[1536] {0};
char tmbuf[1][64] {{0}};
@ -199,7 +199,7 @@ ircd::gpt::generate(const vector_view<u16> &out,
const size_t report_size = snprintf
(
report, sizeof(report),
"%4u:%-4u %4u:%-4u %1u%1u [ %4.1f%% %6.2f%% %5.2fL ] %5.1f%% %5.1f%% %4.1fL %s %04x %8s %8s | %8s",
"%4u:%-4u %4u:%-4u %1u%1u [ %4.1f%% %6.2f%% %5.2fL %5.2fL ] %5.1f%% %5.1f%% %4.1fL %4.1fL %s %04x %8s %8s | %8s",
j + in.size(),
ctrl.tokens,
ctrl.epoch,
@ -209,9 +209,11 @@ ircd::gpt::generate(const vector_view<u16> &out,
ctrl.cert_mean < 100.0? ctrl.cert_mean: NAN,
ctrl.perp_mean < 100.0? ctrl.perp_mean: NAN,
ctrl.loss_mean < 100.0? ctrl.loss_mean: NAN,
ctrl.l2_loss_mean < 100.0? ctrl.l2_loss_mean: NAN,
ctrl.cert < 100.0? ctrl.cert: NAN,
ctrl.perp < 100.0? ctrl.perp: NAN,
ctrl.loss < 100.0? ctrl.loss: NAN,
ctrl.l2_loss < 100.0? ctrl.l2_loss: NAN,
vocab::debug(dbuf, out[j]).c_str(),
out[j],
pretty(tmbuf[0], milliseconds(last_time / bsz), 1).c_str(),

View file

@ -29,10 +29,15 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
const uint height,
const uint i)
{
float4 acc = bias[i];
const uint
lanes = 4;
float4
acc = bias[i];
for(uint j = 0; j < height; ++j)
for(uint k = 0; k < 4; ++k)
acc += in[j][k] * weight[width * (j * 4 + k) + i];
for(uint k = 0; k < lanes; ++k)
acc += in[j][k] * weight[width * (j * lanes + k) + i];
out[i] = acc;
}
@ -77,10 +82,10 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
width = opts->ffnn_width,
height = opts->ffnn_height;
for(uint i = 0; i < 4; ++i)
for(uint i = 0; i < opts->ffnn_mult; ++i)
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, i * ln + li);
for(uint i = 0; i < 4; ++i)
for(uint i = 0; i < opts->ffnn_mult; ++i)
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
}
@ -337,11 +342,11 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_task *const ctrl,
barrier(CLK_LOCAL_MEM_FENCE);
// Fully connected
for(uint i = 0; i < 3; ++i)
for(uint i = 0; i < opts->attn_mult; ++i)
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, i * ln + li);
// Export queries, keys, and values.
for(uint i = 0; i < 3; ++i)
for(uint i = 0; i < opts->attn_mult; ++i)
state[wi].proj[i][li] = token.proj[i][li];
}
@ -648,7 +653,273 @@ ircd_gpt_lm_select(__global struct ircd_gpt_task *const ctrl,
if(logsm[j] > logsm[idx[li]])
idx[li] = j;
ircd_simt_sort_idx16_flldr(idx, logsm, ln, li);
ircd_simt_sort_idx16_flldr(idx, logsm);
ircd_gpt_lm_result(ctrl, opts, li, idx, logsm, logexp, logit);
ircd_gpt_leave(ctrl, opts, li);
}
//
// backpropagations
//
/// Adam-style (adaptive moment estimation) parameter update for one float4
/// lane of a parameter tensor. Each work-item updates the element at its
/// local id, maintaining the first moment (exp_avg) and second moment
/// (exp_avg_sqr) accumulators alongside the parameter itself.
///
/// NOTE(review): the "gradient" is the scalar ctrl->loss_mean broadcast to
/// all vector lanes — a true per-parameter gradient appears not to be
/// plumbed through yet; confirm against the surrounding pipeline.
/// NOTE(review): no bias correction (1 - beta^t) is applied as in canonical
/// Adam; the step==0 zero-initialization of both moments stands in for it.
inline void
ircd_gpt_prop_elem(__global const struct ircd_gpt_task *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__global float4 *const restrict param_,
__global float4 *const restrict exp_avg_,
__global float4 *const restrict exp_avg_sqr_)
{
// One element per work-item; step 0 means the moment buffers are
// uninitialized and are treated as zero below.
const uint
li = get_local_id(0),
step = ctrl->step;
// Entire update computed as one chained const declaration; evaluation
// order of the initializers is significant.
const float4
param = param_[li],
grad = ctrl->loss_mean,
// alpha[i] = 1 - beta[i]: blend factors for the new gradient sample
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
// On the first step the stored moments are garbage; substitute zero.
exp_avg = step? exp_avg_[li]: 0.0f,
exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,
// m = beta1*m + (1-beta1)*g
exp_avg_mul = exp_avg * opts->beta[0],
exp_avg_dot = exp_avg_mul + alpha[0] * grad,
// v = beta2*v + (1-beta2)*g^2
exp_avg_sqr_mul = exp_avg_sqr * opts->beta[1],
exp_avg_sqr_dot = exp_avg_sqr_mul + alpha[1] * grad * grad,
// theta -= lr * m / (sqrt(v) + eps)
denom = sqrt(exp_avg_sqr_dot) + opts->epsilon,
delta = opts->alpha * (exp_avg_dot / denom),
update = param - delta;
// Write back the updated parameter and both moment accumulators.
param_[li] = update;
exp_avg_[li] = exp_avg_dot;
exp_avg_sqr_[li] = exp_avg_sqr_dot;
}
/// Backpropagation step for a layer-norm unit: applies one optimizer update
/// to the norm's bias and weight vectors. One float4 lane is updated per
/// work-item (the local id is consumed inside ircd_gpt_prop_elem).
///
/// @param ctrl       Task control page (read-only); supplies step & loss.
/// @param opts       Task options (alpha, beta[2], epsilon).
/// @param bias       Norm bias parameter vector.
/// @param bias_m0    First-moment accumulator for the bias.
/// @param bias_m1    Second-moment accumulator for the bias.
/// @param weight     Norm weight parameter vector.
/// @param weight_m0  First-moment accumulator for the weight.
/// @param weight_m1  Second-moment accumulator for the weight.
__kernel void
ircd_gpt_norm_prop(__global const struct ircd_gpt_task *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global union ircd_gpt_tokenv *const restrict bias,
                   __global union ircd_gpt_tokenv *const restrict bias_m0,
                   __global union ircd_gpt_tokenv *const restrict bias_m1,
                   __global union ircd_gpt_tokenv *const restrict weight,
                   __global union ircd_gpt_tokenv *const restrict weight_m0,
                   __global union ircd_gpt_tokenv *const restrict weight_m1)
{
	// Bias vector update
	ircd_gpt_prop_elem
	(
		ctrl, opts,
		bias->word,
		bias_m0->word,
		bias_m1->word
	);

	// Weight vector update
	ircd_gpt_prop_elem
	(
		ctrl, opts,
		weight->word,
		weight_m0->word,
		weight_m1->word
	);
}
/// Backpropagation step for one attention block: updates the norm, the
/// fully-connected (query/key/value) aperature, and the projection unit.
///
/// Loop bounds use the option fields introduced by this commit
/// (attn_mult for the fcon aperature width multiple, embed_elems for the
/// weight-matrix row count) rather than the hard-coded 3/768 used elsewhere
/// before this change, keeping the kernel consistent with
/// ircd_gpt_attn_fcon and correct for non-default model geometry.
__kernel void
ircd_gpt_coil_prop_attn(__global const struct ircd_gpt_task *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global union ircd_gpt_tokenv *const restrict norm_bias,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict norm_weight,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m1,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_bias,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_bias_m0,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_bias_m1,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_weight,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_weight_m0,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_weight_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_bias,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_weight,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m1)
{
	// Layer-norm bias & weight
	ircd_gpt_norm_prop
	(
		ctrl, opts,
		norm_bias,
		norm_bias_m0,
		norm_bias_m1,
		norm_weight,
		norm_weight_m0,
		norm_weight_m1
	);

	// Fcon bias: one projection (query/key/value) per aperature lane.
	for(uint j = 0; j < opts->attn_mult; ++j)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			fcon_bias->proj[j],
			fcon_bias_m0->proj[j],
			fcon_bias_m1->proj[j]
		);

	// Fcon weight: one row per embedding element.
	for(uint i = 0; i < opts->embed_elems; ++i)
		for(uint j = 0; j < opts->attn_mult; ++j)
			ircd_gpt_prop_elem
			(
				ctrl, opts,
				fcon_weight[i].proj[j],
				fcon_weight_m0[i].proj[j],
				fcon_weight_m1[i].proj[j]
			);

	// Projection bias
	ircd_gpt_prop_elem
	(
		ctrl, opts,
		proj_bias->word,
		proj_bias_m0->word,
		proj_bias_m1->word
	);

	// Projection weight: square matrix, one row per embedding element.
	for(uint i = 0; i < opts->embed_elems; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			proj_weight[i].word,
			proj_weight_m0[i].word,
			proj_weight_m1[i].word
		);
}
/// Backpropagation step for one MLP (feed-forward) block: updates the norm,
/// the fully-connected expansion aperature, and the projection unit.
///
/// Loop bounds use the option fields introduced by this commit
/// (ffnn_mult for the fcon aperature width multiple, embed_elems /
/// ffnn_elems for the weight-matrix row counts) rather than the hard-coded
/// 4/768/3072, keeping the kernel consistent with ircd_gpt_ffnn_fcon and
/// correct for non-default model geometry.
__kernel void
ircd_gpt_coil_prop_ffnn(__global const struct ircd_gpt_task *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global union ircd_gpt_tokenv *const restrict norm_bias,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict norm_weight,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m1,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m0,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m1,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m0,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_bias,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_weight,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m1)
{
	// Layer-norm bias & weight
	ircd_gpt_norm_prop
	(
		ctrl, opts,
		norm_bias,
		norm_bias_m0,
		norm_bias_m1,
		norm_weight,
		norm_weight_m0,
		norm_weight_m1
	);

	// Fcon bias: one projection per expansion lane.
	for(uint j = 0; j < opts->ffnn_mult; ++j)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			fcon_bias->proj[j],
			fcon_bias_m0->proj[j],
			fcon_bias_m1->proj[j]
		);

	// Fcon weight: one row per embedding element, expanded ffnn_mult-wide.
	for(uint i = 0; i < opts->embed_elems; ++i)
		for(uint j = 0; j < opts->ffnn_mult; ++j)
			ircd_gpt_prop_elem
			(
				ctrl, opts,
				fcon_weight[i].proj[j],
				fcon_weight_m0[i].proj[j],
				fcon_weight_m1[i].proj[j]
			);

	// Projection bias
	ircd_gpt_prop_elem
	(
		ctrl, opts,
		proj_bias->word,
		proj_bias_m0->word,
		proj_bias_m1->word
	);

	// Projection weight: one row per expanded (ffnn) element.
	for(uint i = 0; i < opts->ffnn_elems; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			proj_weight[i].word,
			proj_weight_m0[i].word,
			proj_weight_m1[i].word
		);
}
/// Backpropagation step for the language-model embedding: updates the
/// positional and token embedding tables. The context-token rows and the
/// vocabulary (logit) rows are partitioned evenly across work-groups; each
/// group updates its contiguous slice of both tables.
///
/// @param ctrl      Task control page (read-only).
/// @param opts      Task options; context_tokens and logits size the tables.
/// @param pos       Positional embedding table (one tokenv per position).
/// @param pos_m0    First-moment accumulators for pos.
/// @param pos_m1    Second-moment accumulators for pos.
/// @param token     Token embedding table (one tokenv per vocab entry).
/// @param token_m0  First-moment accumulators for token.
/// @param token_m1  Second-moment accumulators for token.
__kernel void
ircd_gpt_lm_embed_prop(__global const struct ircd_gpt_task *const ctrl,
                       __constant const struct ircd_gpt_opts *const opts,
                       __global union ircd_gpt_tokenv *const restrict pos,
                       __global union ircd_gpt_tokenv *const restrict pos_m0,
                       __global union ircd_gpt_tokenv *const restrict pos_m1,
                       __global union ircd_gpt_tokenv *const restrict token,
                       __global union ircd_gpt_tokenv *const restrict token_m0,
                       __global union ircd_gpt_tokenv *const restrict token_m1)
{
	const uint
	wi = get_group_id(0),
	wn = get_num_groups(0),
	cn = opts->context_tokens / wn,   // context rows per group
	ci = cn * wi,                     // this group's first context row
	tn = opts->logits / wn,           // vocab rows per group
	ti = tn * wi;                     // this group's first vocab row

	// Positional embedding slice
	for(uint i = ci; i < ci + cn; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			pos[i].word,
			pos_m0[i].word,
			pos_m1[i].word
		);

	// Token embedding slice
	for(uint i = ti; i < ti + tn; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			token[i].word,
			token_m0[i].word,
			token_m1[i].word
		);
}

View file

@ -161,12 +161,16 @@ ircd::gpt::model::init_from_cache(const string_view &cache_path)
const fs::fd fd
{
cache_path
cache_path, std::ios::in | std::ios::out
};
fs::map::opts map_opts
{
std::ios::in | std::ios::out
};
fs::map::opts map_opts;
map_opts.huge2mb = true;
map_opts.locked = true;
map_opts.locked = false;
default_model_shm = fs::map
{
fd, map_opts, sizeof(decoder)

View file

@ -14,7 +14,8 @@ namespace ircd::gpt::pipe
static ircd::cl::exec::opts
negative_opts, positive_opts, selfattn_opts,
cathode_opts, anode_opts, lmhead_opts, lmamax_opts;
cathode_opts, anode_opts, lmhead_opts, lmamax_opts,
backprop_opts;
extern conf::item<size_t> flush_cycles;
extern conf::item<size_t> queue_cycles;
@ -53,7 +54,7 @@ ircd::gpt::pipe::handle_quit
void
ircd::gpt::pipe::init()
{
const auto &default_model
const gpt::model::decoder &default_model
{
*gpt::model::default_model
};
@ -475,8 +476,8 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ctrl,
opts,
accum,
model.embed->pos,
model.embed->token,
model.embed->pos.param,
model.embed->token.param,
}
,lm_norm
{
@ -485,8 +486,8 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ctrl,
opts,
accum,
model.decode->norm.bias,
model.decode->norm.weight,
model.decode->norm.bias.param,
model.decode->norm.weight.param,
}
,lm_logit
{
@ -496,7 +497,7 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
opts,
logit,
accum,
model.embed->token,
model.embed->token.param,
}
,lm_logsm
{
@ -518,6 +519,32 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
logexp,
logit,
}
,lm_norm_backprop
{
code,
"ircd_gpt_norm_prop",
ctrl,
opts,
model.decode->norm.bias.param,
model.decode->norm.bias.moment[0],
model.decode->norm.bias.moment[1],
model.decode->norm.weight.param,
model.decode->norm.weight.moment[0],
model.decode->norm.weight.moment[1],
}
,lm_embed_backprop
{
code,
"ircd_gpt_lm_embed_prop",
ctrl,
opts,
model.embed->pos.param,
model.embed->pos.moment[0],
model.embed->pos.moment[1],
model.embed->token.param,
model.embed->token.moment[0],
model.embed->token.moment[1],
}
,layer
{
std::make_unique<struct desc::layer>(*this, 0x00),
@ -550,10 +577,10 @@ ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
desc.opts,
desc.state,
desc.accum,
desc.model->decode->block[laynum].attn.norm.bias,
desc.model->decode->block[laynum].attn.norm.weight,
desc.model->decode->block[laynum].attn.fcon.bias,
desc.model->decode->block[laynum].attn.fcon.weight,
desc.model->decode->block[laynum].attn.norm.bias.param,
desc.model->decode->block[laynum].attn.norm.weight.param,
desc.model->decode->block[laynum].attn.fcon.bias.param,
desc.model->decode->block[laynum].attn.fcon.weight.param,
}
,positive
{
@ -564,14 +591,64 @@ ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
desc.accum,
desc.state,
desc.model->decode->block[laynum].attn.mask,
desc.model->decode->block[laynum].attn.proj.bias,
desc.model->decode->block[laynum].attn.proj.weight,
desc.model->decode->block[laynum].ffnn.norm.bias,
desc.model->decode->block[laynum].ffnn.norm.weight,
desc.model->decode->block[laynum].ffnn.fcon.bias,
desc.model->decode->block[laynum].ffnn.fcon.weight,
desc.model->decode->block[laynum].ffnn.proj.bias,
desc.model->decode->block[laynum].ffnn.proj.weight,
desc.model->decode->block[laynum].attn.proj.bias.param,
desc.model->decode->block[laynum].attn.proj.weight.param,
desc.model->decode->block[laynum].ffnn.norm.bias.param,
desc.model->decode->block[laynum].ffnn.norm.weight.param,
desc.model->decode->block[laynum].ffnn.fcon.bias.param,
desc.model->decode->block[laynum].ffnn.fcon.weight.param,
desc.model->decode->block[laynum].ffnn.proj.bias.param,
desc.model->decode->block[laynum].ffnn.proj.weight.param,
}
,backattn
{
*desc.code,
"ircd_gpt_coil_prop_attn",
desc.ctrl,
desc.opts,
desc.model->decode->block[laynum].attn.norm.bias.param,
desc.model->decode->block[laynum].attn.norm.bias.moment[0],
desc.model->decode->block[laynum].attn.norm.bias.moment[1],
desc.model->decode->block[laynum].attn.norm.weight.param,
desc.model->decode->block[laynum].attn.norm.weight.moment[0],
desc.model->decode->block[laynum].attn.norm.weight.moment[1],
desc.model->decode->block[laynum].attn.fcon.bias.param,
desc.model->decode->block[laynum].attn.fcon.bias.moment[0],
desc.model->decode->block[laynum].attn.fcon.bias.moment[1],
desc.model->decode->block[laynum].attn.fcon.weight.param,
desc.model->decode->block[laynum].attn.fcon.weight.moment[0],
desc.model->decode->block[laynum].attn.fcon.weight.moment[1],
desc.model->decode->block[laynum].attn.proj.bias.param,
desc.model->decode->block[laynum].attn.proj.bias.moment[0],
desc.model->decode->block[laynum].attn.proj.bias.moment[1],
desc.model->decode->block[laynum].attn.proj.weight.param,
desc.model->decode->block[laynum].attn.proj.weight.moment[0],
desc.model->decode->block[laynum].attn.proj.weight.moment[1],
}
,backffnn
{
*desc.code,
"ircd_gpt_coil_prop_ffnn",
desc.ctrl,
desc.opts,
desc.model->decode->block[laynum].ffnn.norm.bias.param,
desc.model->decode->block[laynum].ffnn.norm.bias.moment[0],
desc.model->decode->block[laynum].ffnn.norm.bias.moment[1],
desc.model->decode->block[laynum].ffnn.norm.weight.param,
desc.model->decode->block[laynum].ffnn.norm.weight.moment[0],
desc.model->decode->block[laynum].ffnn.norm.weight.moment[1],
desc.model->decode->block[laynum].ffnn.fcon.bias.param,
desc.model->decode->block[laynum].ffnn.fcon.bias.moment[0],
desc.model->decode->block[laynum].ffnn.fcon.bias.moment[1],
desc.model->decode->block[laynum].ffnn.fcon.weight.param,
desc.model->decode->block[laynum].ffnn.fcon.weight.moment[0],
desc.model->decode->block[laynum].ffnn.fcon.weight.moment[1],
desc.model->decode->block[laynum].ffnn.proj.bias.param,
desc.model->decode->block[laynum].ffnn.proj.bias.moment[0],
desc.model->decode->block[laynum].ffnn.proj.bias.moment[1],
desc.model->decode->block[laynum].ffnn.proj.weight.param,
desc.model->decode->block[laynum].ffnn.proj.weight.moment[0],
desc.model->decode->block[laynum].ffnn.proj.weight.moment[1],
}
{
}