0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-11-28 17:52:54 +01:00

ircd::gpt::pipe: Enable mutable model; fixes for backpropagation; range stub.

This commit is contained in:
Jason Volk 2022-10-16 21:16:55 +00:00
parent 47117dde9a
commit 2609c21913
4 changed files with 15 additions and 13 deletions

View file

@ -30,6 +30,7 @@ namespace ircd::gpt::model
extern std::vector<json::object> default_data; extern std::vector<json::object> default_data;
constexpr auto alignment {4096}; constexpr auto alignment {4096};
extern conf::item<bool> cache_shared;
} }
/// Layer normalization /// Layer normalization

View file

@ -256,10 +256,15 @@ try
} }
,model ,model
{ {
std::make_unique<pipe::model> !gpt::model::cache_shared?
( std::make_unique<pipe::model>
*const_cast<const gpt::model::decoder *>(gpt::model::default_model) (
) *const_cast<const gpt::model::decoder *>(gpt::model::default_model)
):
std::make_unique<pipe::model>
(
*const_cast<gpt::model::decoder *>(gpt::model::default_model)
)
} }
,desc ,desc
{ {

View file

@ -1167,7 +1167,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
const float4 const float4
param = param_[li], param = param_[li],
grad = ctrl->label[0].loss.mean, grad = ctrl->target.loss.mean,
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], }, alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
exp_avg = ts? exp_avg_[li]: 0.0f, exp_avg = ts? exp_avg_[li]: 0.0f,
exp_avg_sqr = ts? exp_avg_sqr_[li]: 0.0f, exp_avg_sqr = ts? exp_avg_sqr_[li]: 0.0f,
@ -1179,13 +1179,9 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
delta = opts->alpha * (exp_avg_dot / denom), delta = opts->alpha * (exp_avg_dot / denom),
update = param - delta; update = param - delta;
param_[li] = param + FLT_EPSILON; param_[li] = update;
exp_avg_[li] = exp_avg + FLT_EPSILON; exp_avg_[li] = exp_avg_dot;
exp_avg_sqr_[li] = exp_avg_sqr + FLT_EPSILON; exp_avg_sqr_[li] = exp_avg_sqr_dot;
//param_[li] = update;
//exp_avg_[li] = exp_avg_dot;
//exp_avg_sqr_[li] = exp_avg_sqr_dot;
} }
// //

View file

@ -238,7 +238,7 @@ ircd::gpt::pipe::cycle::cycle(gpt::samp &samp)
tokens, tokens,
cached, cached,
true, true,
false, ((false) && gpt::model::cache_shared)
} }
,stage ,stage
{ {