diff --git a/include/ircd/gpt/ctrl.h b/include/ircd/gpt/ctrl.h
index 4ee1e8631..dad122ae2 100644
--- a/include/ircd/gpt/ctrl.h
+++ b/include/ircd/gpt/ctrl.h
@@ -60,6 +60,41 @@ struct ircd_gpt_ctrl_tokens
 	ulong witnessed;
 };
 
+/// Target label register (abridged)
+///
+struct ircd_gpt_ctrl_logit
+{
+	/// Vocabulary token.
+	ushort token;
+
+	/// Padding #0.
+	ushort _pad0;
+
+	/// Result logit softmax probability.
+	float samax;
+};
+
+/// Target label register (full)
+///
+struct ircd_gpt_ctrl_label
+{
+	/// Vocabulary token.
+	ushort token;
+
+	/// Padding #0.
+	ushort _pad0;
+
+	/// Result logit softmax probability.
+	float samax;
+
+	/// Loss state
+	struct ircd_math_mean loss;
+
+	/// Perplexity state
+	struct ircd_math_mean perp;
+}
+__attribute__((aligned(64)));
+
 /// Task Control Page
 ///
 /// The control block is shared with our device software. Execution state is
@@ -76,24 +111,22 @@ struct ircd_gpt_ctrl
 	/// buffer; the buffer with the tokens themselves is elsewhere.
 	struct ircd_gpt_ctrl_tokens tokens;
 
-	/// Logit softmax state
+	/// Top result summary from the sorted result logit softmax vector. This
+	/// is updated each cycle by device software with extended statistics on
+	/// the top N results.
+	struct ircd_gpt_ctrl_logit top[16];
+
+	/// Target label control block. Results for each target are registered
+	/// and state is updated each cycle.
+	struct ircd_gpt_ctrl_label label[4];
+
+	/// Result logit vector softmax internal state.
 	struct ircd_math_samax samax;
 
-	/// Target label loss state
-	struct ircd_math_mean loss;
-
-	/// Target label perplexity score state
-	struct ircd_math_mean perp;
-
-	/// Target label certainty difference state
-	struct ircd_math_mean cert;
-
-	/// PRNG xoshiro256 state. This is the de facto random seed which can be
-	/// set before cycle entry by the host. It is updated by device software
-	/// when used.
+	/// PRNG xoshiro256 internal state (note: see opts.h to seed the prng).
 	ulong rand[4];
 
-	/// Perform backprop
+	/// Perform backprop TODO: XXX
 	bool prop;
 
 	/// Header magic 0xC7012C70
diff --git a/include/ircd/gpt/opts.h b/include/ircd/gpt/opts.h
index 390cc7c03..48a243233 100644
--- a/include/ircd/gpt/opts.h
+++ b/include/ircd/gpt/opts.h
@@ -45,6 +45,15 @@ struct ircd_gpt_opts
 	/// Flip a random coin between 0 and top_p ( = 90 = 0.9) for logit select.
 	uint top_p;
 
+	/// Registers the top N result logits in the ctrl block each cycle.
+	uint top_n;
+
+	/// Number of target labels to register results for in the ctrl block.
+	uint labels;
+
+	/// Bitbar toggling various debug modes.
+	uint debug;
+
 	/// Specifies the token context size in tokens.
 	uint context_tokens;
 
@@ -99,12 +108,6 @@ struct ircd_gpt_opts
 	/// Testing steps
 	uint testing_steps;
 
-	/// Target label
-	ushort label;
-
-	/// Bitbar toggling various debug modes
-	ushort debug;
-
 	/// Learning rate
 	float alpha;
 
diff --git a/ircd/gpt.cc b/ircd/gpt.cc
index 097db6fca..af058b05c 100644
--- a/ircd/gpt.cc
+++ b/ircd/gpt.cc
@@ -135,7 +135,7 @@ ircd::gpt::generate(task &task)
 			cycles
 		};
 
-		backprop(task, ctrl.loss.mean, *model::default_model, momentum);
+		backprop(task, ctrl.label[0].loss.mean, *model::default_model, momentum);
 	}
 
 	if(ctrl.prop)
@@ -143,17 +143,15 @@ ircd::gpt::generate(task &task)
 		log::debug
 		{
 			log, "Backpropagation of %2.6f in %lu cycles.",
-			ctrl.loss.mean,
+			ctrl.label[0].loss.mean,
 			cycles,
 		};
 
 		ctrl.epic.epoch = 0;
-		ctrl.loss.mean = 0;
-		ctrl.loss.last = ctrl.loss.mean;
-		ctrl.perp.mean = 0;
-		ctrl.perp.last = ctrl.perp.mean;
-		ctrl.cert.mean = 0;
-		ctrl.cert.last = ctrl.cert.mean;
+		ctrl.label[0].loss.mean = 0;
+		ctrl.label[0].loss.last = ctrl.label[0].loss.mean;
+		ctrl.label[0].perp.mean = 0;
+		ctrl.label[0].perp.last = ctrl.label[0].perp.mean;
 		ctrl.prop = false;
 		pipe::default_model->invalid = true;
 		return;
@@ -208,16 +206,16 @@ ircd::gpt::generate_debug(task &task,
 		ctrl.tokens.count,
 		ctrl.epic.epoch,
 		ctrl.epic.cycle,
-		std::clamp(ctrl.cert.mean * 100.0f, 0.0f, 100.0f),
-		std::clamp(ctrl.perp.mean, 0.0f, 100.0f),
-		std::clamp(ctrl.loss.mean, 0.0f, 99.99f),
-		opts.label == tok? '+': ' ',
+		0.0f, // cert
+		std::clamp(ctrl.label[0].perp.mean, 0.0f, 100.0f),
+		std::clamp(ctrl.label[0].loss.mean, 0.0f, 99.99f),
+		ctrl.label[0].token == tok? '+': ' ',
 		' ', // flag place
 		' ', // flag place
-		opts.label,
-		std::clamp(ctrl.loss.last, 0.0f, 99.99f),
-		std::clamp(ctrl.perp.last, 0.0f, 100.0f),
-		std::clamp(ctrl.cert.last * 100.0f, 0.0f, 100.0f),
+		ctrl.label[0].token,
+		std::clamp(ctrl.label[0].loss.last, 0.0f, 99.99f),
+		std::clamp(ctrl.label[0].perp.last, 0.0f, 100.0f),
+		0.0f, // cert
 		vocab::debug(dbuf, tok).c_str(),
 		tok,
 		pretty(tmbuf[0], milliseconds(0ms / bsz), 1).c_str(),
@@ -287,6 +285,18 @@ noexcept
 {
 	90U
 }
+,top_n
+{
+	16
+}
+,labels
+{
+	0
+}
+,debug
+{
+	0x01
+}
 ,context_tokens
 {
 	1024U
@@ -359,14 +369,6 @@ noexcept
 {
 	5000
 }
-,label
-{
-	198
-}
-,debug
-{
-	0x01
-}
 ,alpha
 {
 	0.001f
diff --git a/ircd/gpt_gpu.cl b/ircd/gpt_gpu.cl
index b145f6dcf..b3856486b 100644
--- a/ircd/gpt_gpu.cl
+++ b/ircd/gpt_gpu.cl
@@ -595,38 +595,75 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_ctrl *const ctrl,
 
 inline void
 __attribute__((always_inline))
-ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
-               __constant const struct ircd_gpt_opts *const opts,
-               const uint li)
+ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
+                       __constant const struct ircd_gpt_opts *const opts,
+                       __local const ushort *const restrict idx,
+                       __global const float *const restrict logsm,
+                       __global const float *const restrict logexp,
+                       __global const float *const restrict logit,
+                       const uint i)
 {
-	// No action for other threads right now
-	if(li != 0)
-		return;
+	const ushort
+	token = idx[i];
 
-	if(ctrl->epic.cycle + 1 >= opts->limit)
-		ctrl->epic.epoch += 1;
+	const float
+	samax = logsm[token];
 
-	ctrl->epic.cycle += 1;
-	ctrl->magic = 0xC7012C70U;
+	ctrl->top[i].token = token;
+	ctrl->top[i].samax = samax;
 }
 
 inline void
 __attribute__((always_inline))
-ircd_gpt_lm_result(__global struct ircd_gpt_ctrl *const ctrl,
-                   __constant const struct ircd_gpt_opts *const opts,
-                   const uint li,
-                   __local const ushort *const restrict idx,
-                   __global const float *const restrict logsm,
-                   __global const float *const restrict logexp,
-                   __global const float *const restrict logit)
+ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
+                         __constant const struct ircd_gpt_opts *const opts,
+                         __local const ushort *const restrict idx,
+                         __global const float *const restrict logsm,
+                         __global const float *const restrict logexp,
+                         __global const float *const restrict logit,
+                         const uint i)
 {
-	// To read from cells other than idx[0] we need this barrier.
-	barrier(CLK_LOCAL_MEM_FENCE);
+	__global struct ircd_gpt_ctrl_label
+	*const label = ctrl->label + i;
 
-	// Mask for write-leader
-	if(li != 0)
-		return;
+	const ushort
+	token = label->token,
+	sum_sel = ctrl->epic.cycle % 3;
 
+	const float
+	samax = logsm[token],
+	mean_div = ctrl->epic.cycle + 1.0f;
+
+	const float
+	loss = 0.0f - log(samax),
+	loss_sum = label->loss.sum[0] + label->loss.sum[1] + label->loss.sum[2] + loss,
+	loss_mean = loss_sum / mean_div;
+
+	const float
+	perp = (1.0f - samax) * native_log2(opts->logits),
+	perp_sum = label->perp.sum[0] + label->perp.sum[1] + label->perp.sum[2] + perp,
+	perp_mean = perp_sum / mean_div;
+
+	label->samax = samax;
+
+	label->loss.last = loss;
+	label->loss.sum[sum_sel] += loss;
+	label->loss.mean = loss_mean;
+
+	label->perp.last = perp;
+	label->perp.sum[sum_sel] += perp;
+	label->perp.mean = perp_mean;
+}
+
+inline void
+__attribute__((always_inline))
+ircd_gpt_lm_result_select(__global struct ircd_gpt_ctrl *const ctrl,
+                          __constant const struct ircd_gpt_opts *const opts,
+                          __local const ushort *const restrict idx,
+                          __global const float *const restrict logsm,
+                          __global const float *const restrict logexp,
+                          __global const float *const restrict logit)
+{
 	const bool
 	buffer_full = ctrl->tokens.count >= opts->buffer_tokens;
 
@@ -654,35 +691,19 @@ ircd_gpt_lm_result(__global struct ircd_gpt_ctrl *const ctrl,
 	ctrl->tokens.head = head;
 	ctrl->tokens.count = tokens;
 	ctrl->token[dest] = token;
+}
 
-	const ushort
-	ln = get_local_size(0),
-	next_select = (select + 1) % ln,
-	next_token = idx[next_select],
-	sum_sel = ctrl->epic.epoch % 3;
+inline void
+__attribute__((always_inline))
+ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
+               __constant const struct ircd_gpt_opts *const opts,
+               const uint li)
+{
+	if(ctrl->epic.cycle + 1 >= opts->limit)
+		ctrl->epic.epoch += 1;
 
-	const float
-	test_lsm = logexp[opts->label],
-	loss = 0.0f - log(test_lsm * ctrl->samax.lambda),
-	perp = (1.0f - logsm[token]) * native_log2(opts->logits),
-	cert = (logsm[token] - logsm[next_token]) / logsm[token],
-	loss_sum = ctrl->loss.sum[0] + ctrl->loss.sum[1] + ctrl->loss.sum[2] + loss,
-	perp_sum = ctrl->perp.sum[0] + ctrl->perp.sum[1] + ctrl->perp.sum[2] + perp,
-	cert_sum = ctrl->cert.sum[0] + ctrl->cert.sum[1] + ctrl->cert.sum[2] + cert,
-	mean_div = ctrl->epic.epoch + 1.0f,
-	loss_mean = loss_sum / mean_div,
-	perp_mean = perp_sum / mean_div,
-	cert_mean = cert_sum / mean_div;
-
-	ctrl->loss.last = loss;
-	ctrl->loss.sum[sum_sel] += loss;
-	ctrl->loss.mean = loss_mean;
-	ctrl->perp.last = perp;
-	ctrl->perp.sum[sum_sel] += perp;
-	ctrl->perp.mean = perp_mean;
-	ctrl->cert.last = cert;
-	ctrl->cert.sum[sum_sel] += cert;
-	ctrl->cert.mean = cert_mean;
+	ctrl->epic.cycle += 1;
+	ctrl->magic = 0xC7012C70U;
 }
 
 __kernel void
@@ -707,7 +728,23 @@ ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
 	idx[li] = j;
 	ircd_simt_sort_idx16_flldr(idx, logsm);
 
-	ircd_gpt_lm_result(ctrl, opts, li, idx, logsm, logexp, logit);
+
+	// Writes to `idx` from the sort are still pending across threads.
+	barrier(CLK_LOCAL_MEM_FENCE);
+
+	if(li < opts->top_n)
+		ircd_gpt_lm_result_top(ctrl, opts, idx, logsm, logexp, logit, li);
+
+	if(li < opts->labels)
+		ircd_gpt_lm_result_label(ctrl, opts, idx, logsm, logexp, logit, li);
+
+	// Mask for write-leader
+	if(li == 0)
+		ircd_gpt_lm_result_select(ctrl, opts, idx, logsm, logexp, logit);
+
+	if(li != 0)
+		return;
+
 	ircd_gpt_leave(ctrl, opts, li);
 }
 
@@ -729,7 +766,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
 	const float4
 	param = param_[li],
-	grad = ctrl->loss.mean,
+	grad = ctrl->label[0].loss.mean,
 	alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
 	exp_avg = step? exp_avg_[li]: 0.0f,
 	exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,
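
Note on the per-label math in ircd_gpt_lm_result_label() above: for a label with softmax probability p, the kernel registers loss = -log(p) and perplexity = (1 - p) * log2(logits), folding each sample into the running mean through the three rotating sum[] buckets of struct ircd_math_mean. A minimal host-side C sketch of the same bookkeeping follows; struct math_mean, mean_push() and label_push() are illustrative stand-ins for the real ircd_math_mean, not part of this patch.

#include <math.h>

/* Illustrative mirror of the device bookkeeping in
 * ircd_gpt_lm_result_label(); assumes the same sum[3] layout
 * as struct ircd_math_mean. */
struct math_mean
{
	float last;    /* most recent sample */
	float mean;    /* running mean over all cycles so far */
	float sum[3];  /* rotating partial sums, selected by cycle % 3 */
};

static void
mean_push(struct math_mean *const m,
          const float sample,
          const unsigned long cycle)
{
	/* Total the buckets plus the new sample before the bucket is
	 * updated, exactly as the kernel computes loss_sum/perp_sum. */
	const unsigned sel = cycle % 3;
	const float sum = m->sum[0] + m->sum[1] + m->sum[2] + sample;

	m->last = sample;
	m->sum[sel] += sample;
	m->mean = sum / (cycle + 1.0f);
}

static void
label_push(struct math_mean *const loss,
           struct math_mean *const perp,
           const float samax,         /* softmax probability of the label */
           const unsigned logits,     /* vocabulary width, opts->logits */
           const unsigned long cycle)
{
	mean_push(loss, 0.0f - logf(samax), cycle);
	mean_push(perp, (1.0f - samax) * log2f(logits), cycle);
}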
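
Setup note: with the scalar opts.label (old default 198) removed and the new labels count defaulting to 0, targets are now opted into per label slot. The kernel reads ctrl->label[i].token but never writes it, so the host is expected to seed each slot before cycling. A sketch of that setup with example values; the zeroing of the mean state is an assumption (gpt.cc only resets label[0] after backprop), and 198 is simply the removed default target:

#include <string.h>

/* Illustrative host-side setup; ctrl/opts point into the shared
 * task control page per ctrl.h and opts.h above. */
static void
setup_labels(struct ircd_gpt_ctrl *const ctrl,
             struct ircd_gpt_opts *const opts)
{
	opts->top_n = 16;            /* register all 16 sorted logits */
	opts->labels = 1;            /* track one target label */

	ctrl->label[0].token = 198;  /* e.g. the old default target */
	memset(&ctrl->label[0].loss, 0, sizeof(ctrl->label[0].loss));
	memset(&ctrl->label[0].perp, 0, sizeof(ctrl->label[0].perp));
}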
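
Since the ctrl page is host-shared, the new top[] and label[] registers can then be read directly between cycles. A hedged reading sketch, assuming a mapped ctrl/opts pair and that the device's writes for the cycle have completed; the dump function itself is illustrative:

#include <stdio.h>

static void
dump_results(const struct ircd_gpt_ctrl *const ctrl,
             const struct ircd_gpt_opts *const opts)
{
	/* top[] holds at most 16 entries per ctrl.h; opts->top_n
	 * says how many the device filled this cycle. */
	for(unsigned i = 0; i < opts->top_n && i < 16; ++i)
		printf("top[%u] token:%u p:%.6f\n",
		       i, (unsigned)ctrl->top[i].token, ctrl->top[i].samax);

	/* label[] holds at most 4 entries per ctrl.h. */
	for(unsigned i = 0; i < opts->labels && i < 4; ++i)
		printf("label[%u] token:%u loss:%.4f perp:%.4f\n",
		       i, (unsigned)ctrl->label[i].token,
		       ctrl->label[i].loss.mean,
		       ctrl->label[i].perp.mean);
}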