mirror of
https://github.com/matrix-construct/construct
synced 2024-11-29 02:02:38 +01:00
ircd::gpt: Add backpropagation pipe.
This commit is contained in:
parent
14a1561cad
commit
d5eb1e3a87
8 changed files with 470 additions and 43 deletions
|
@ -21,7 +21,7 @@ namespace ircd::gpt::model
|
|||
struct decoder;
|
||||
|
||||
constexpr auto align {64};
|
||||
extern const decoder *default_model;
|
||||
extern decoder *default_model;
|
||||
extern string_view default_dataset;
|
||||
extern std::vector<json::object> default_data;
|
||||
}
|
||||
|
|
|
@ -98,11 +98,29 @@ struct ircd_gpt_opts
|
|||
#endif
|
||||
;
|
||||
|
||||
/// Attention unit fcon width multiple
|
||||
uint attn_mult
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
3U
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
||||
/// MLP unit fcon width multiple
|
||||
uint ffnn_mult
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
4U
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
||||
/// Attention unit width multiple
|
||||
uint attn_elems
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
embed_elems * 3
|
||||
embed_elems * attn_mult
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -111,7 +129,16 @@ struct ircd_gpt_opts
|
|||
uint ffnn_elems
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
embed_elems * 4
|
||||
embed_elems * ffnn_mult
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
||||
/// SIMD lane count
|
||||
uint lanes
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
4U
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -119,7 +146,7 @@ struct ircd_gpt_opts
|
|||
uint embed_width
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
embed_elems / 4
|
||||
embed_elems / lanes
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -127,7 +154,7 @@ struct ircd_gpt_opts
|
|||
uint attn_width
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
attn_elems / 4
|
||||
attn_elems / lanes
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -135,7 +162,7 @@ struct ircd_gpt_opts
|
|||
uint attn_height
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
embed_elems / 4
|
||||
embed_elems / lanes
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -143,7 +170,7 @@ struct ircd_gpt_opts
|
|||
uint ffnn_width
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
ffnn_elems / 4
|
||||
ffnn_elems / lanes
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -151,7 +178,7 @@ struct ircd_gpt_opts
|
|||
uint ffnn_height
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
embed_elems / 4
|
||||
embed_elems / lanes
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
@ -199,6 +226,31 @@ struct ircd_gpt_opts
|
|||
}
|
||||
#endif
|
||||
;
|
||||
|
||||
float alpha
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
0.001
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
||||
float beta[2]
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
0.9, // Beta1
|
||||
0.999, // Beta2
|
||||
}
|
||||
#endif
|
||||
;
|
||||
|
||||
float epsilon
|
||||
#ifdef __cplusplus
|
||||
{
|
||||
0.000001
|
||||
}
|
||||
#endif
|
||||
;
|
||||
}
|
||||
__attribute__((aligned(4096)));
|
||||
|
||||
|
|
|
@ -33,7 +33,9 @@ struct ircd::gpt::pipe::desc
|
|||
lm_norm,
|
||||
lm_logit,
|
||||
lm_logsm,
|
||||
lm_select;
|
||||
lm_select,
|
||||
lm_norm_backprop,
|
||||
lm_embed_backprop;
|
||||
|
||||
std::unique_ptr<struct desc::layer>
|
||||
layer[12];
|
||||
|
@ -43,8 +45,11 @@ struct ircd::gpt::pipe::desc
|
|||
|
||||
struct ircd::gpt::pipe::desc::layer
|
||||
{
|
||||
cl::kern negative;
|
||||
cl::kern positive;
|
||||
cl::kern
|
||||
negative,
|
||||
positive,
|
||||
backattn,
|
||||
backffnn;
|
||||
|
||||
layer(pipe::desc &, const int);
|
||||
};
|
||||
|
|
|
@ -47,6 +47,10 @@ struct ircd_gpt_task
|
|||
/// Several cycles may occur during each epoch.
|
||||
ulong epoch;
|
||||
|
||||
/// Accumulates the training epoch count for the task. The counter is
|
||||
/// incremented by one in device software for each backward propagation.
|
||||
ulong step;
|
||||
|
||||
/// Accumulates the number of tokens produced by the task. Several tokens
|
||||
/// may be produced each epoch, but currently only one token is produced
|
||||
/// each cycle.
|
||||
|
@ -107,6 +111,18 @@ struct ircd_gpt_task
|
|||
/// Certainty mean over context
|
||||
float cert_mean;
|
||||
|
||||
/// Final loss
|
||||
float l2_loss;
|
||||
|
||||
/// Final loss sum
|
||||
float l2_loss_sum;
|
||||
|
||||
/// Final loss mean
|
||||
float l2_loss_mean;
|
||||
|
||||
/// Perform backprop
|
||||
bool prop;
|
||||
|
||||
/// The token buffer starts at offset 2048 and continues to the end of
|
||||
/// the page; options specify the size of the tokens buffer in tokens.
|
||||
/// Additional pages must be attached for larger buffer sizes.
|
||||
|
|
|
@ -110,7 +110,7 @@ ircd::gpt::generate(const vector_view<u16> &out,
|
|||
else
|
||||
embed(data(dst), in[j], j, opts);
|
||||
|
||||
#if RB_DEBUG
|
||||
#if 0 // RB_DEBUG
|
||||
static char dbuf[512] {0};
|
||||
char report[1536] {0};
|
||||
char tmbuf[1][64] {{0}};
|
||||
|
@ -199,7 +199,7 @@ ircd::gpt::generate(const vector_view<u16> &out,
|
|||
const size_t report_size = snprintf
|
||||
(
|
||||
report, sizeof(report),
|
||||
"%4u:%-4u %4u:%-4u %1u%1u [ %4.1f%% %6.2f%% %5.2fL ] %5.1f%% %5.1f%% %4.1fL %s %04x %8s %8s | %8s",
|
||||
"%4u:%-4u %4u:%-4u %1u%1u [ %4.1f%% %6.2f%% %5.2fL %5.2fL ] %5.1f%% %5.1f%% %4.1fL %4.1fL %s %04x %8s %8s | %8s",
|
||||
j + in.size(),
|
||||
ctrl.tokens,
|
||||
ctrl.epoch,
|
||||
|
@ -209,9 +209,11 @@ ircd::gpt::generate(const vector_view<u16> &out,
|
|||
ctrl.cert_mean < 100.0? ctrl.cert_mean: NAN,
|
||||
ctrl.perp_mean < 100.0? ctrl.perp_mean: NAN,
|
||||
ctrl.loss_mean < 100.0? ctrl.loss_mean: NAN,
|
||||
ctrl.l2_loss_mean < 100.0? ctrl.l2_loss_mean: NAN,
|
||||
ctrl.cert < 100.0? ctrl.cert: NAN,
|
||||
ctrl.perp < 100.0? ctrl.perp: NAN,
|
||||
ctrl.loss < 100.0? ctrl.loss: NAN,
|
||||
ctrl.l2_loss < 100.0? ctrl.l2_loss: NAN,
|
||||
vocab::debug(dbuf, out[j]).c_str(),
|
||||
out[j],
|
||||
pretty(tmbuf[0], milliseconds(last_time / bsz), 1).c_str(),
|
||||
|
|
287
ircd/gpt_cl.cl
287
ircd/gpt_cl.cl
|
@ -29,10 +29,15 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
|
|||
const uint height,
|
||||
const uint i)
|
||||
{
|
||||
float4 acc = bias[i];
|
||||
const uint
|
||||
lanes = 4;
|
||||
|
||||
float4
|
||||
acc = bias[i];
|
||||
|
||||
for(uint j = 0; j < height; ++j)
|
||||
for(uint k = 0; k < 4; ++k)
|
||||
acc += in[j][k] * weight[width * (j * 4 + k) + i];
|
||||
for(uint k = 0; k < lanes; ++k)
|
||||
acc += in[j][k] * weight[width * (j * lanes + k) + i];
|
||||
|
||||
out[i] = acc;
|
||||
}
|
||||
|
@ -77,10 +82,10 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
|||
width = opts->ffnn_width,
|
||||
height = opts->ffnn_height;
|
||||
|
||||
for(uint i = 0; i < 4; ++i)
|
||||
for(uint i = 0; i < opts->ffnn_mult; ++i)
|
||||
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, i * ln + li);
|
||||
|
||||
for(uint i = 0; i < 4; ++i)
|
||||
for(uint i = 0; i < opts->ffnn_mult; ++i)
|
||||
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
|
||||
}
|
||||
|
||||
|
@ -337,11 +342,11 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_task *const ctrl,
|
|||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Fully connected
|
||||
for(uint i = 0; i < 3; ++i)
|
||||
for(uint i = 0; i < opts->attn_mult; ++i)
|
||||
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, i * ln + li);
|
||||
|
||||
// Export queries, keys, and values.
|
||||
for(uint i = 0; i < 3; ++i)
|
||||
for(uint i = 0; i < opts->attn_mult; ++i)
|
||||
state[wi].proj[i][li] = token.proj[i][li];
|
||||
}
|
||||
|
||||
|
@ -648,7 +653,273 @@ ircd_gpt_lm_select(__global struct ircd_gpt_task *const ctrl,
|
|||
if(logsm[j] > logsm[idx[li]])
|
||||
idx[li] = j;
|
||||
|
||||
ircd_simt_sort_idx16_flldr(idx, logsm, ln, li);
|
||||
ircd_simt_sort_idx16_flldr(idx, logsm);
|
||||
ircd_gpt_lm_result(ctrl, opts, li, idx, logsm, logexp, logit);
|
||||
ircd_gpt_leave(ctrl, opts, li);
|
||||
}
|
||||
|
||||
//
|
||||
// backpropagations
|
||||
//
|
||||
|
||||
inline void
|
||||
ircd_gpt_prop_elem(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__global float4 *const restrict param_,
|
||||
__global float4 *const restrict exp_avg_,
|
||||
__global float4 *const restrict exp_avg_sqr_)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
step = ctrl->step;
|
||||
|
||||
const float4
|
||||
param = param_[li],
|
||||
grad = ctrl->loss_mean,
|
||||
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
|
||||
exp_avg = step? exp_avg_[li]: 0.0f,
|
||||
exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,
|
||||
exp_avg_mul = exp_avg * opts->beta[0],
|
||||
exp_avg_dot = exp_avg_mul + alpha[0] * grad,
|
||||
exp_avg_sqr_mul = exp_avg_sqr * opts->beta[1],
|
||||
exp_avg_sqr_dot = exp_avg_sqr_mul + alpha[1] * grad * grad,
|
||||
denom = sqrt(exp_avg_sqr_dot) + opts->epsilon,
|
||||
delta = opts->alpha * (exp_avg_dot / denom),
|
||||
update = param - delta;
|
||||
|
||||
param_[li] = update;
|
||||
exp_avg_[li] = exp_avg_dot;
|
||||
exp_avg_sqr_[li] = exp_avg_sqr_dot;
|
||||
}
|
||||
|
||||
__kernel void
|
||||
ircd_gpt_norm_prop(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__global union ircd_gpt_tokenv *const restrict bias,
|
||||
__global union ircd_gpt_tokenv *const restrict bias_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict bias_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict weight,
|
||||
__global union ircd_gpt_tokenv *const restrict weight_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict weight_m1)
|
||||
{
|
||||
const uint
|
||||
gi = get_global_id(0),
|
||||
gn = get_global_size(0),
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
wi = get_group_id(0),
|
||||
wn = get_num_groups(0);
|
||||
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
bias->word,
|
||||
bias_m0->word,
|
||||
bias_m1->word
|
||||
);
|
||||
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
weight->word,
|
||||
weight_m0->word,
|
||||
weight_m1->word
|
||||
);
|
||||
}
|
||||
|
||||
__kernel void
|
||||
ircd_gpt_coil_prop_attn(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_bias,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_bias_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_bias_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_weight,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_weight_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_weight_m1,
|
||||
__global union ircd_gpt_attn_aperaturev *const restrict fcon_bias,
|
||||
__global union ircd_gpt_attn_aperaturev *const restrict fcon_bias_m0,
|
||||
__global union ircd_gpt_attn_aperaturev *const restrict fcon_bias_m1,
|
||||
__global union ircd_gpt_attn_aperaturev *const restrict fcon_weight,
|
||||
__global union ircd_gpt_attn_aperaturev *const restrict fcon_weight_m0,
|
||||
__global union ircd_gpt_attn_aperaturev *const restrict fcon_weight_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_bias,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_bias_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_bias_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_weight,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_weight_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_weight_m1)
|
||||
{
|
||||
const uint
|
||||
gi = get_global_id(0),
|
||||
gn = get_global_size(0),
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
wi = get_group_id(0),
|
||||
wn = get_num_groups(0);
|
||||
|
||||
ircd_gpt_norm_prop
|
||||
(
|
||||
ctrl, opts,
|
||||
norm_bias,
|
||||
norm_bias_m0,
|
||||
norm_bias_m1,
|
||||
norm_weight,
|
||||
norm_weight_m0,
|
||||
norm_weight_m1
|
||||
);
|
||||
|
||||
for(uint j = 0; j < 3; ++j)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
fcon_bias->proj[j],
|
||||
fcon_bias_m0->proj[j],
|
||||
fcon_bias_m1->proj[j]
|
||||
);
|
||||
|
||||
for(uint i = 0; i < 768; ++i)
|
||||
for(uint j = 0; j < 3; ++j)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
fcon_weight[i].proj[j],
|
||||
fcon_weight_m0[i].proj[j],
|
||||
fcon_weight_m1[i].proj[j]
|
||||
);
|
||||
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
proj_bias->word,
|
||||
proj_bias_m0->word,
|
||||
proj_bias_m1->word
|
||||
);
|
||||
|
||||
for(uint i = 0; i < 768; ++i)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
proj_weight[i].word,
|
||||
proj_weight_m0[i].word,
|
||||
proj_weight_m1[i].word
|
||||
);
|
||||
}
|
||||
|
||||
__kernel void
|
||||
ircd_gpt_coil_prop_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_bias,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_bias_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_bias_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_weight,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_weight_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict norm_weight_m1,
|
||||
__global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
|
||||
__global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m0,
|
||||
__global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m1,
|
||||
__global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
|
||||
__global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m0,
|
||||
__global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_bias,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_bias_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_bias_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_weight,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_weight_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict proj_weight_m1)
|
||||
{
|
||||
const uint
|
||||
gi = get_global_id(0),
|
||||
gn = get_global_size(0),
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
wi = get_group_id(0),
|
||||
wn = get_num_groups(0);
|
||||
|
||||
ircd_gpt_norm_prop
|
||||
(
|
||||
ctrl, opts,
|
||||
norm_bias,
|
||||
norm_bias_m0,
|
||||
norm_bias_m1,
|
||||
norm_weight,
|
||||
norm_weight_m0,
|
||||
norm_weight_m1
|
||||
);
|
||||
|
||||
for(uint j = 0; j < 4; ++j)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
fcon_bias->proj[j],
|
||||
fcon_bias_m0->proj[j],
|
||||
fcon_bias_m1->proj[j]
|
||||
);
|
||||
|
||||
for(uint i = 0; i < 768; ++i)
|
||||
for(uint j = 0; j < 4; ++j)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
fcon_weight[i].proj[j],
|
||||
fcon_weight_m0[i].proj[j],
|
||||
fcon_weight_m1[i].proj[j]
|
||||
);
|
||||
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
proj_bias->word,
|
||||
proj_bias_m0->word,
|
||||
proj_bias_m1->word
|
||||
);
|
||||
|
||||
for(uint i = 0; i < 3072; ++i)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
proj_weight[i].word,
|
||||
proj_weight_m0[i].word,
|
||||
proj_weight_m1[i].word
|
||||
);
|
||||
}
|
||||
|
||||
__kernel void
|
||||
ircd_gpt_lm_embed_prop(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__global union ircd_gpt_tokenv *const restrict pos,
|
||||
__global union ircd_gpt_tokenv *const restrict pos_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict pos_m1,
|
||||
__global union ircd_gpt_tokenv *const restrict token,
|
||||
__global union ircd_gpt_tokenv *const restrict token_m0,
|
||||
__global union ircd_gpt_tokenv *const restrict token_m1)
|
||||
{
|
||||
const uint
|
||||
gi = get_global_id(0),
|
||||
gn = get_global_size(0),
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
wi = get_group_id(0),
|
||||
wn = get_num_groups(0),
|
||||
cn = opts->context_tokens / wn,
|
||||
ci = cn * wi,
|
||||
tn = opts->logits / wn,
|
||||
ti = tn * wi;
|
||||
|
||||
for(uint i = ci; i < ci + cn; ++i)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
pos[i].word,
|
||||
pos_m0[i].word,
|
||||
pos_m1[i].word
|
||||
);
|
||||
|
||||
for(uint i = ti; i < ti + tn; ++i)
|
||||
ircd_gpt_prop_elem
|
||||
(
|
||||
ctrl, opts,
|
||||
token[i].word,
|
||||
token_m0[i].word,
|
||||
token_m1[i].word
|
||||
);
|
||||
}
|
||||
|
|
|
@ -161,12 +161,16 @@ ircd::gpt::model::init_from_cache(const string_view &cache_path)
|
|||
|
||||
const fs::fd fd
|
||||
{
|
||||
cache_path
|
||||
cache_path, std::ios::in | std::ios::out
|
||||
};
|
||||
|
||||
fs::map::opts map_opts
|
||||
{
|
||||
std::ios::in | std::ios::out
|
||||
};
|
||||
|
||||
fs::map::opts map_opts;
|
||||
map_opts.huge2mb = true;
|
||||
map_opts.locked = true;
|
||||
map_opts.locked = false;
|
||||
default_model_shm = fs::map
|
||||
{
|
||||
fd, map_opts, sizeof(decoder)
|
||||
|
|
115
ircd/gpt_pipe.cc
115
ircd/gpt_pipe.cc
|
@ -14,7 +14,8 @@ namespace ircd::gpt::pipe
|
|||
|
||||
static ircd::cl::exec::opts
|
||||
negative_opts, positive_opts, selfattn_opts,
|
||||
cathode_opts, anode_opts, lmhead_opts, lmamax_opts;
|
||||
cathode_opts, anode_opts, lmhead_opts, lmamax_opts,
|
||||
backprop_opts;
|
||||
|
||||
extern conf::item<size_t> flush_cycles;
|
||||
extern conf::item<size_t> queue_cycles;
|
||||
|
@ -53,7 +54,7 @@ ircd::gpt::pipe::handle_quit
|
|||
void
|
||||
ircd::gpt::pipe::init()
|
||||
{
|
||||
const auto &default_model
|
||||
const gpt::model::decoder &default_model
|
||||
{
|
||||
*gpt::model::default_model
|
||||
};
|
||||
|
@ -475,8 +476,8 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
ctrl,
|
||||
opts,
|
||||
accum,
|
||||
model.embed->pos,
|
||||
model.embed->token,
|
||||
model.embed->pos.param,
|
||||
model.embed->token.param,
|
||||
}
|
||||
,lm_norm
|
||||
{
|
||||
|
@ -485,8 +486,8 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
ctrl,
|
||||
opts,
|
||||
accum,
|
||||
model.decode->norm.bias,
|
||||
model.decode->norm.weight,
|
||||
model.decode->norm.bias.param,
|
||||
model.decode->norm.weight.param,
|
||||
}
|
||||
,lm_logit
|
||||
{
|
||||
|
@ -496,7 +497,7 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
opts,
|
||||
logit,
|
||||
accum,
|
||||
model.embed->token,
|
||||
model.embed->token.param,
|
||||
}
|
||||
,lm_logsm
|
||||
{
|
||||
|
@ -518,6 +519,32 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
logexp,
|
||||
logit,
|
||||
}
|
||||
,lm_norm_backprop
|
||||
{
|
||||
code,
|
||||
"ircd_gpt_norm_prop",
|
||||
ctrl,
|
||||
opts,
|
||||
model.decode->norm.bias.param,
|
||||
model.decode->norm.bias.moment[0],
|
||||
model.decode->norm.bias.moment[1],
|
||||
model.decode->norm.weight.param,
|
||||
model.decode->norm.weight.moment[0],
|
||||
model.decode->norm.weight.moment[1],
|
||||
}
|
||||
,lm_embed_backprop
|
||||
{
|
||||
code,
|
||||
"ircd_gpt_lm_embed_prop",
|
||||
ctrl,
|
||||
opts,
|
||||
model.embed->pos.param,
|
||||
model.embed->pos.moment[0],
|
||||
model.embed->pos.moment[1],
|
||||
model.embed->token.param,
|
||||
model.embed->token.moment[0],
|
||||
model.embed->token.moment[1],
|
||||
}
|
||||
,layer
|
||||
{
|
||||
std::make_unique<struct desc::layer>(*this, 0x00),
|
||||
|
@ -550,10 +577,10 @@ ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
|
|||
desc.opts,
|
||||
desc.state,
|
||||
desc.accum,
|
||||
desc.model->decode->block[laynum].attn.norm.bias,
|
||||
desc.model->decode->block[laynum].attn.norm.weight,
|
||||
desc.model->decode->block[laynum].attn.fcon.bias,
|
||||
desc.model->decode->block[laynum].attn.fcon.weight,
|
||||
desc.model->decode->block[laynum].attn.norm.bias.param,
|
||||
desc.model->decode->block[laynum].attn.norm.weight.param,
|
||||
desc.model->decode->block[laynum].attn.fcon.bias.param,
|
||||
desc.model->decode->block[laynum].attn.fcon.weight.param,
|
||||
}
|
||||
,positive
|
||||
{
|
||||
|
@ -564,14 +591,64 @@ ircd::gpt::pipe::desc::layer::layer(pipe::desc &desc,
|
|||
desc.accum,
|
||||
desc.state,
|
||||
desc.model->decode->block[laynum].attn.mask,
|
||||
desc.model->decode->block[laynum].attn.proj.bias,
|
||||
desc.model->decode->block[laynum].attn.proj.weight,
|
||||
desc.model->decode->block[laynum].ffnn.norm.bias,
|
||||
desc.model->decode->block[laynum].ffnn.norm.weight,
|
||||
desc.model->decode->block[laynum].ffnn.fcon.bias,
|
||||
desc.model->decode->block[laynum].ffnn.fcon.weight,
|
||||
desc.model->decode->block[laynum].ffnn.proj.bias,
|
||||
desc.model->decode->block[laynum].ffnn.proj.weight,
|
||||
desc.model->decode->block[laynum].attn.proj.bias.param,
|
||||
desc.model->decode->block[laynum].attn.proj.weight.param,
|
||||
desc.model->decode->block[laynum].ffnn.norm.bias.param,
|
||||
desc.model->decode->block[laynum].ffnn.norm.weight.param,
|
||||
desc.model->decode->block[laynum].ffnn.fcon.bias.param,
|
||||
desc.model->decode->block[laynum].ffnn.fcon.weight.param,
|
||||
desc.model->decode->block[laynum].ffnn.proj.bias.param,
|
||||
desc.model->decode->block[laynum].ffnn.proj.weight.param,
|
||||
}
|
||||
,backattn
|
||||
{
|
||||
*desc.code,
|
||||
"ircd_gpt_coil_prop_attn",
|
||||
desc.ctrl,
|
||||
desc.opts,
|
||||
desc.model->decode->block[laynum].attn.norm.bias.param,
|
||||
desc.model->decode->block[laynum].attn.norm.bias.moment[0],
|
||||
desc.model->decode->block[laynum].attn.norm.bias.moment[1],
|
||||
desc.model->decode->block[laynum].attn.norm.weight.param,
|
||||
desc.model->decode->block[laynum].attn.norm.weight.moment[0],
|
||||
desc.model->decode->block[laynum].attn.norm.weight.moment[1],
|
||||
desc.model->decode->block[laynum].attn.fcon.bias.param,
|
||||
desc.model->decode->block[laynum].attn.fcon.bias.moment[0],
|
||||
desc.model->decode->block[laynum].attn.fcon.bias.moment[1],
|
||||
desc.model->decode->block[laynum].attn.fcon.weight.param,
|
||||
desc.model->decode->block[laynum].attn.fcon.weight.moment[0],
|
||||
desc.model->decode->block[laynum].attn.fcon.weight.moment[1],
|
||||
desc.model->decode->block[laynum].attn.proj.bias.param,
|
||||
desc.model->decode->block[laynum].attn.proj.bias.moment[0],
|
||||
desc.model->decode->block[laynum].attn.proj.bias.moment[1],
|
||||
desc.model->decode->block[laynum].attn.proj.weight.param,
|
||||
desc.model->decode->block[laynum].attn.proj.weight.moment[0],
|
||||
desc.model->decode->block[laynum].attn.proj.weight.moment[1],
|
||||
}
|
||||
,backffnn
|
||||
{
|
||||
*desc.code,
|
||||
"ircd_gpt_coil_prop_ffnn",
|
||||
desc.ctrl,
|
||||
desc.opts,
|
||||
desc.model->decode->block[laynum].ffnn.norm.bias.param,
|
||||
desc.model->decode->block[laynum].ffnn.norm.bias.moment[0],
|
||||
desc.model->decode->block[laynum].ffnn.norm.bias.moment[1],
|
||||
desc.model->decode->block[laynum].ffnn.norm.weight.param,
|
||||
desc.model->decode->block[laynum].ffnn.norm.weight.moment[0],
|
||||
desc.model->decode->block[laynum].ffnn.norm.weight.moment[1],
|
||||
desc.model->decode->block[laynum].ffnn.fcon.bias.param,
|
||||
desc.model->decode->block[laynum].ffnn.fcon.bias.moment[0],
|
||||
desc.model->decode->block[laynum].ffnn.fcon.bias.moment[1],
|
||||
desc.model->decode->block[laynum].ffnn.fcon.weight.param,
|
||||
desc.model->decode->block[laynum].ffnn.fcon.weight.moment[0],
|
||||
desc.model->decode->block[laynum].ffnn.fcon.weight.moment[1],
|
||||
desc.model->decode->block[laynum].ffnn.proj.bias.param,
|
||||
desc.model->decode->block[laynum].ffnn.proj.bias.moment[0],
|
||||
desc.model->decode->block[laynum].ffnn.proj.bias.moment[1],
|
||||
desc.model->decode->block[laynum].ffnn.proj.weight.param,
|
||||
desc.model->decode->block[laynum].ffnn.proj.weight.moment[0],
|
||||
desc.model->decode->block[laynum].ffnn.proj.weight.moment[1],
|
||||
}
|
||||
{
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue