mirror of
https://github.com/matrix-construct/construct
synced 2025-02-18 09:40:12 +01:00
ircd::gpt::pipe: Various statistical instrumentation.
This commit is contained in:
parent
2a3c54afa2
commit
f61239a52c
3 changed files with 51 additions and 54 deletions
|
@ -75,24 +75,6 @@ struct ircd_gpt_task
|
|||
/// State counters for the accept/error sequence codes.
|
||||
uint accept_seq[4], error_seq[4];
|
||||
|
||||
/// Loss for last token of last cycle
|
||||
float loss;
|
||||
|
||||
/// Sum loss over all cycles
|
||||
float loss_sum;
|
||||
|
||||
/// Average loss over all cycles
|
||||
float loss_mean;
|
||||
|
||||
/// Perplexity score for last token of last cycle
|
||||
float perp;
|
||||
|
||||
/// Perplexity sum over all cycles
|
||||
float perp_sum;
|
||||
|
||||
/// Perplexity mean over context
|
||||
float perp_mean;
|
||||
|
||||
/// Logit softmax mu
|
||||
float samax_mu;
|
||||
|
||||
|
@ -102,11 +84,29 @@ struct ircd_gpt_task
|
|||
/// Logit softmax lambda
|
||||
float samax_lambda;
|
||||
|
||||
/// Loss for last token of last cycle
|
||||
float loss;
|
||||
|
||||
/// Sum loss over all cycles
|
||||
float loss_sum[4];
|
||||
|
||||
/// Average loss over all cycles
|
||||
float loss_mean;
|
||||
|
||||
/// Perplexity score for last token of last cycle
|
||||
float perp;
|
||||
|
||||
/// Sum ppl over all cycles
|
||||
float perp_sum[4];
|
||||
|
||||
/// Perplexity mean over context
|
||||
float perp_mean;
|
||||
|
||||
/// Certainty difference score for last token of last cycle
|
||||
float cert;
|
||||
|
||||
/// Certainty sum over all cycles
|
||||
float cert_sum;
|
||||
/// Sum certainty over all cycles
|
||||
float cert_sum[4];
|
||||
|
||||
/// Certainty mean over context
|
||||
float cert_mean;
|
||||
|
@ -114,9 +114,6 @@ struct ircd_gpt_task
|
|||
/// Final loss
|
||||
float l2_loss;
|
||||
|
||||
/// Final loss sum
|
||||
float l2_loss_sum;
|
||||
|
||||
/// Final loss mean
|
||||
float l2_loss_mean;
|
||||
|
||||
|
|
24
ircd/gpt.cc
24
ircd/gpt.cc
|
@ -199,21 +199,21 @@ ircd::gpt::generate(const vector_view<u16> &out,
|
|||
const size_t report_size = snprintf
|
||||
(
|
||||
report, sizeof(report),
|
||||
"%4u:%-4u %4u:%-4u %1u%1u [ %4.1f%% %6.2f%% %5.2fL %5.2fL ] %5.1f%% %5.1f%% %4.1fL %4.1fL %s %04x %8s %8s | %8s",
|
||||
"%4lu:%-4u %4lu:%-4lu %6.1f%% %5.1fP %6.3fL [%c%c%c] %5u %6.3fL %6.2fP %5.1f%% %s %04x %8s %8s | %8s",
|
||||
j + in.size(),
|
||||
ctrl.tokens,
|
||||
ctrl.epoch,
|
||||
ctrl.cycle,
|
||||
accc[0] + accc[1] + accc[2],
|
||||
errc[0] + errc[1] + errc[2],
|
||||
ctrl.cert_mean < 100.0? ctrl.cert_mean: NAN,
|
||||
ctrl.perp_mean < 100.0? ctrl.perp_mean: NAN,
|
||||
ctrl.loss_mean < 100.0? ctrl.loss_mean: NAN,
|
||||
ctrl.l2_loss_mean < 100.0? ctrl.l2_loss_mean: NAN,
|
||||
ctrl.cert < 100.0? ctrl.cert: NAN,
|
||||
ctrl.perp < 100.0? ctrl.perp: NAN,
|
||||
ctrl.loss < 100.0? ctrl.loss: NAN,
|
||||
ctrl.l2_loss < 100.0? ctrl.l2_loss: NAN,
|
||||
std::clamp(ctrl.cert_mean * 100.0f, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.perp_mean, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.loss_mean, 0.0f, 99.99f),
|
||||
opts.label == out[j]? '+': ' ',
|
||||
accc[0] + accc[1] + accc[2] >= 3? 'A': ' ',
|
||||
errc[0] + errc[1] + errc[2] >= 3? 'E': ' ',
|
||||
opts.label,
|
||||
std::clamp(ctrl.loss, 0.0f, 99.99f),
|
||||
std::clamp(ctrl.perp, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.cert * 100.0f, 0.0f, 100.0f),
|
||||
vocab::debug(dbuf, out[j]).c_str(),
|
||||
out[j],
|
||||
pretty(tmbuf[0], milliseconds(last_time / bsz), 1).c_str(),
|
||||
|
@ -230,7 +230,7 @@ ircd::gpt::generate(const vector_view<u16> &out,
|
|||
}
|
||||
|
||||
ret = ctrl.tokens - in.size();
|
||||
for(uint i(0); i < 3; ++i)
|
||||
if ((false)) for(uint i(0); i < 3; ++i)
|
||||
if(accc_thresh[i] && ctrl.accept_seq[i] >= accc_thresh[i])
|
||||
{
|
||||
ret -= (3 - accc_thresh[i]);
|
||||
|
|
|
@ -190,7 +190,9 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
|||
for(uint i = 0; i < wn; ++i)
|
||||
sum += self[i][li];
|
||||
|
||||
const float lambda = 1.0f / sum;
|
||||
const float
|
||||
lambda = 1.0f / sum;
|
||||
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
self[i][li] *= lambda;
|
||||
}
|
||||
|
@ -615,34 +617,32 @@ ircd_gpt_lm_result(__global struct ircd_gpt_task *const ctrl,
|
|||
ctrl->tokens = tokens;
|
||||
ctrl->token[dest] = token;
|
||||
|
||||
if(opts->top_k > 1)
|
||||
return;
|
||||
|
||||
const ushort
|
||||
next_select = select + 1,
|
||||
next_token = idx[next_select];
|
||||
ln = get_local_size(0),
|
||||
next_select = (select + 1) % ln,
|
||||
next_token = idx[next_select],
|
||||
sum_sel = ctrl->epoch % 3;
|
||||
|
||||
const float
|
||||
test_lsm = logexp[opts->label] * ctrl->samax_lambda,
|
||||
test_lsm = logexp[opts->label],
|
||||
loss = 0.0f - log(test_lsm * ctrl->samax_lambda),
|
||||
perp = logsm[token] * 100.0f,
|
||||
cert = ((logsm[token] - logsm[next_token]) / logsm[token]) * 100.0f,
|
||||
loss_sum = ctrl->loss_sum + loss,
|
||||
perp_sum = ctrl->perp_sum + perp,
|
||||
cert_sum = ctrl->cert_sum + cert,
|
||||
mean_div = ctrl->epoch + 1.0f,
|
||||
loss_mean = loss_sum / mean_div,
|
||||
perp_mean = perp_sum / mean_div,
|
||||
cert_mean = cert_sum / mean_div;
|
||||
perp = (1.0f - logsm[token]) * native_log2(opts->logits),
|
||||
cert = (logsm[token] - logsm[next_token]) / logsm[token],
|
||||
loss_sum = ctrl->loss_sum[0] + ctrl->loss_sum[1] + ctrl->loss_sum[2] + loss,
|
||||
perp_sum = ctrl->perp_sum[0] + ctrl->perp_sum[1] + ctrl->perp_sum[2] + perp,
|
||||
cert_sum = ctrl->cert_sum[0] + ctrl->cert_sum[1] + ctrl->cert_sum[2] + cert,
|
||||
loss_mean = loss_sum / (ctrl->epoch + 1.0f),
|
||||
perp_mean = perp_sum / (ctrl->epoch + 1.0f),
|
||||
cert_mean = cert_sum / (ctrl->epoch + 1.0f);
|
||||
|
||||
ctrl->loss = loss;
|
||||
ctrl->loss_sum = loss_sum;
|
||||
ctrl->loss_sum[sum_sel] += loss;
|
||||
ctrl->loss_mean = loss_mean;
|
||||
ctrl->perp = perp;
|
||||
ctrl->perp_sum = perp_sum;
|
||||
ctrl->perp_sum[sum_sel] += perp;
|
||||
ctrl->perp_mean = perp_mean;
|
||||
ctrl->cert = cert;
|
||||
ctrl->cert_sum = cert_sum;
|
||||
ctrl->cert_sum[sum_sel] += cert;
|
||||
ctrl->cert_mean = cert_mean;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue