0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-06-11 06:28:55 +02:00

ircd::gpt::pipe: Various statistical instrumentation.

This commit is contained in:
Jason Volk 2021-04-22 12:15:31 -07:00
parent 2a3c54afa2
commit f61239a52c
3 changed files with 51 additions and 54 deletions

View file

@ -75,24 +75,6 @@ struct ircd_gpt_task
/// State counters for the accept/error sequence codes.
uint accept_seq[4], error_seq[4];
/// Loss for last token of last cycle
float loss;
/// Sum loss over all cycles
float loss_sum;
/// Average loss over all cycles
float loss_mean;
/// Perplexity score for last token of last cycle
float perp;
/// Perplexity sum over all cycles
float perp_sum;
/// Perplexity mean over context
float perp_mean;
/// Logit softmax mu
float samax_mu;
@ -102,11 +84,29 @@ struct ircd_gpt_task
/// Logit softmax lambda
float samax_lambda;
/// Loss for last token of last cycle
float loss;
/// Sum loss over all cycles
float loss_sum[4];
/// Average loss over all cycles
float loss_mean;
/// Perplexity score for last token of last cycle
float perp;
/// Sum perplexity over all cycles
float perp_sum[4];
/// Perplexity mean over context
float perp_mean;
/// Certainty difference score for last token of last cycle
float cert;
/// Certainty sum over all cycles
float cert_sum;
/// Sum certainty over all cycles
float cert_sum[4];
/// Certainty mean over context
float cert_mean;
@ -114,9 +114,6 @@ struct ircd_gpt_task
/// Final loss
float l2_loss;
/// Final loss sum
float l2_loss_sum;
/// Final loss mean
float l2_loss_mean;

View file

@ -199,21 +199,21 @@ ircd::gpt::generate(const vector_view<u16> &out,
const size_t report_size = snprintf
(
report, sizeof(report),
"%4u:%-4u %4u:%-4u %1u%1u [ %4.1f%% %6.2f%% %5.2fL %5.2fL ] %5.1f%% %5.1f%% %4.1fL %4.1fL %s %04x %8s %8s | %8s",
"%4lu:%-4u %4lu:%-4lu %6.1f%% %5.1fP %6.3fL [%c%c%c] %5u %6.3fL %6.2fP %5.1f%% %s %04x %8s %8s | %8s",
j + in.size(),
ctrl.tokens,
ctrl.epoch,
ctrl.cycle,
accc[0] + accc[1] + accc[2],
errc[0] + errc[1] + errc[2],
ctrl.cert_mean < 100.0? ctrl.cert_mean: NAN,
ctrl.perp_mean < 100.0? ctrl.perp_mean: NAN,
ctrl.loss_mean < 100.0? ctrl.loss_mean: NAN,
ctrl.l2_loss_mean < 100.0? ctrl.l2_loss_mean: NAN,
ctrl.cert < 100.0? ctrl.cert: NAN,
ctrl.perp < 100.0? ctrl.perp: NAN,
ctrl.loss < 100.0? ctrl.loss: NAN,
ctrl.l2_loss < 100.0? ctrl.l2_loss: NAN,
std::clamp(ctrl.cert_mean * 100.0f, 0.0f, 100.0f),
std::clamp(ctrl.perp_mean, 0.0f, 100.0f),
std::clamp(ctrl.loss_mean, 0.0f, 99.99f),
opts.label == out[j]? '+': ' ',
accc[0] + accc[1] + accc[2] >= 3? 'A': ' ',
errc[0] + errc[1] + errc[2] >= 3? 'E': ' ',
opts.label,
std::clamp(ctrl.loss, 0.0f, 99.99f),
std::clamp(ctrl.perp, 0.0f, 100.0f),
std::clamp(ctrl.cert * 100.0f, 0.0f, 100.0f),
vocab::debug(dbuf, out[j]).c_str(),
out[j],
pretty(tmbuf[0], milliseconds(last_time / bsz), 1).c_str(),
@ -230,7 +230,7 @@ ircd::gpt::generate(const vector_view<u16> &out,
}
ret = ctrl.tokens - in.size();
for(uint i(0); i < 3; ++i)
if ((false)) for(uint i(0); i < 3; ++i)
if(accc_thresh[i] && ctrl.accept_seq[i] >= accc_thresh[i])
{
ret -= (3 - accc_thresh[i]);

View file

@ -190,7 +190,9 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
for(uint i = 0; i < wn; ++i)
sum += self[i][li];
const float lambda = 1.0f / sum;
const float
lambda = 1.0f / sum;
for(uint i = 0; i < wn; ++i)
self[i][li] *= lambda;
}
@ -615,34 +617,32 @@ ircd_gpt_lm_result(__global struct ircd_gpt_task *const ctrl,
ctrl->tokens = tokens;
ctrl->token[dest] = token;
if(opts->top_k > 1)
return;
const ushort
next_select = select + 1,
next_token = idx[next_select];
ln = get_local_size(0),
next_select = (select + 1) % ln,
next_token = idx[next_select],
sum_sel = ctrl->epoch % 3;
const float
test_lsm = logexp[opts->label] * ctrl->samax_lambda,
test_lsm = logexp[opts->label],
loss = 0.0f - log(test_lsm * ctrl->samax_lambda),
perp = logsm[token] * 100.0f,
cert = ((logsm[token] - logsm[next_token]) / logsm[token]) * 100.0f,
loss_sum = ctrl->loss_sum + loss,
perp_sum = ctrl->perp_sum + perp,
cert_sum = ctrl->cert_sum + cert,
mean_div = ctrl->epoch + 1.0f,
loss_mean = loss_sum / mean_div,
perp_mean = perp_sum / mean_div,
cert_mean = cert_sum / mean_div;
perp = (1.0f - logsm[token]) * native_log2(opts->logits),
cert = (logsm[token] - logsm[next_token]) / logsm[token],
loss_sum = ctrl->loss_sum[0] + ctrl->loss_sum[1] + ctrl->loss_sum[2] + loss,
perp_sum = ctrl->perp_sum[0] + ctrl->perp_sum[1] + ctrl->perp_sum[2] + perp,
cert_sum = ctrl->cert_sum[0] + ctrl->cert_sum[1] + ctrl->cert_sum[2] + cert,
loss_mean = loss_sum / (ctrl->epoch + 1.0f),
perp_mean = perp_sum / (ctrl->epoch + 1.0f),
cert_mean = cert_sum / (ctrl->epoch + 1.0f);
ctrl->loss = loss;
ctrl->loss_sum = loss_sum;
ctrl->loss_sum[sum_sel] += loss;
ctrl->loss_mean = loss_mean;
ctrl->perp = perp;
ctrl->perp_sum = perp_sum;
ctrl->perp_sum[sum_sel] += perp;
ctrl->perp_mean = perp_mean;
ctrl->cert = cert;
ctrl->cert_sum = cert_sum;
ctrl->cert_sum[sum_sel] += cert;
ctrl->cert_mean = cert_mean;
}