mirror of
https://github.com/matrix-construct/construct
synced 2024-12-01 11:12:51 +01:00
ircd::gpt::pipe: Enable mutable model; fixes for backpropagation; range stub.
This commit is contained in:
parent
47117dde9a
commit
2609c21913
4 changed files with 15 additions and 13 deletions
|
@ -30,6 +30,7 @@ namespace ircd::gpt::model
|
||||||
extern std::vector<json::object> default_data;
|
extern std::vector<json::object> default_data;
|
||||||
|
|
||||||
constexpr auto alignment {4096};
|
constexpr auto alignment {4096};
|
||||||
|
extern conf::item<bool> cache_shared;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Layer normalization
|
/// Layer normalization
|
||||||
|
|
|
@ -256,9 +256,14 @@ try
|
||||||
}
|
}
|
||||||
,model
|
,model
|
||||||
{
|
{
|
||||||
|
!gpt::model::cache_shared?
|
||||||
std::make_unique<pipe::model>
|
std::make_unique<pipe::model>
|
||||||
(
|
(
|
||||||
*const_cast<const gpt::model::decoder *>(gpt::model::default_model)
|
*const_cast<const gpt::model::decoder *>(gpt::model::default_model)
|
||||||
|
):
|
||||||
|
std::make_unique<pipe::model>
|
||||||
|
(
|
||||||
|
*const_cast<gpt::model::decoder *>(gpt::model::default_model)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
,desc
|
,desc
|
||||||
|
|
|
@ -1167,7 +1167,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
|
||||||
|
|
||||||
const float4
|
const float4
|
||||||
param = param_[li],
|
param = param_[li],
|
||||||
grad = ctrl->label[0].loss.mean,
|
grad = ctrl->target.loss.mean,
|
||||||
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
|
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
|
||||||
exp_avg = ts? exp_avg_[li]: 0.0f,
|
exp_avg = ts? exp_avg_[li]: 0.0f,
|
||||||
exp_avg_sqr = ts? exp_avg_sqr_[li]: 0.0f,
|
exp_avg_sqr = ts? exp_avg_sqr_[li]: 0.0f,
|
||||||
|
@ -1179,13 +1179,9 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
|
||||||
delta = opts->alpha * (exp_avg_dot / denom),
|
delta = opts->alpha * (exp_avg_dot / denom),
|
||||||
update = param - delta;
|
update = param - delta;
|
||||||
|
|
||||||
param_[li] = param + FLT_EPSILON;
|
param_[li] = update;
|
||||||
exp_avg_[li] = exp_avg + FLT_EPSILON;
|
exp_avg_[li] = exp_avg_dot;
|
||||||
exp_avg_sqr_[li] = exp_avg_sqr + FLT_EPSILON;
|
exp_avg_sqr_[li] = exp_avg_sqr_dot;
|
||||||
|
|
||||||
//param_[li] = update;
|
|
||||||
//exp_avg_[li] = exp_avg_dot;
|
|
||||||
//exp_avg_sqr_[li] = exp_avg_sqr_dot;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
|
@ -238,7 +238,7 @@ ircd::gpt::pipe::cycle::cycle(gpt::samp &samp)
|
||||||
tokens,
|
tokens,
|
||||||
cached,
|
cached,
|
||||||
true,
|
true,
|
||||||
false,
|
((false) && gpt::model::cache_shared)
|
||||||
}
|
}
|
||||||
,stage
|
,stage
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in a new issue