mirror of
https://github.com/matrix-construct/construct
synced 2024-11-25 00:02:34 +01:00
ircd::gpt::pipe: Enable mutable model; fixes for backpropagation; range stub.
This commit is contained in:
parent
47117dde9a
commit
2609c21913
4 changed files with 15 additions and 13 deletions
|
@ -30,6 +30,7 @@ namespace ircd::gpt::model
|
|||
extern std::vector<json::object> default_data;
|
||||
|
||||
constexpr auto alignment {4096};
|
||||
extern conf::item<bool> cache_shared;
|
||||
}
|
||||
|
||||
/// Layer normalization
|
||||
|
|
13
ircd/gpt.cc
13
ircd/gpt.cc
|
@ -256,10 +256,15 @@ try
|
|||
}
|
||||
,model
|
||||
{
|
||||
std::make_unique<pipe::model>
|
||||
(
|
||||
*const_cast<const gpt::model::decoder *>(gpt::model::default_model)
|
||||
)
|
||||
!gpt::model::cache_shared?
|
||||
std::make_unique<pipe::model>
|
||||
(
|
||||
*const_cast<const gpt::model::decoder *>(gpt::model::default_model)
|
||||
):
|
||||
std::make_unique<pipe::model>
|
||||
(
|
||||
*const_cast<gpt::model::decoder *>(gpt::model::default_model)
|
||||
)
|
||||
}
|
||||
,desc
|
||||
{
|
||||
|
|
|
@ -1167,7 +1167,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
|
||||
const float4
|
||||
param = param_[li],
|
||||
grad = ctrl->label[0].loss.mean,
|
||||
grad = ctrl->target.loss.mean,
|
||||
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
|
||||
exp_avg = ts? exp_avg_[li]: 0.0f,
|
||||
exp_avg_sqr = ts? exp_avg_sqr_[li]: 0.0f,
|
||||
|
@ -1179,13 +1179,9 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
delta = opts->alpha * (exp_avg_dot / denom),
|
||||
update = param - delta;
|
||||
|
||||
param_[li] = param + FLT_EPSILON;
|
||||
exp_avg_[li] = exp_avg + FLT_EPSILON;
|
||||
exp_avg_sqr_[li] = exp_avg_sqr + FLT_EPSILON;
|
||||
|
||||
//param_[li] = update;
|
||||
//exp_avg_[li] = exp_avg_dot;
|
||||
//exp_avg_sqr_[li] = exp_avg_sqr_dot;
|
||||
param_[li] = update;
|
||||
exp_avg_[li] = exp_avg_dot;
|
||||
exp_avg_sqr_[li] = exp_avg_sqr_dot;
|
||||
}
|
||||
|
||||
//
|
||||
|
|
|
@ -238,7 +238,7 @@ ircd::gpt::pipe::cycle::cycle(gpt::samp &samp)
|
|||
tokens,
|
||||
cached,
|
||||
true,
|
||||
false,
|
||||
((false) && gpt::model::cache_shared)
|
||||
}
|
||||
,stage
|
||||
{
|
||||
|
|
Loading…
Reference in a new issue