0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2025-01-14 00:34:18 +01:00

ircd::gpt: Add top N and target label result register control block.

This commit is contained in:
Jason Volk 2021-09-17 23:27:23 -07:00
parent 8bd78af128
commit aea6c79fc2
4 changed files with 170 additions and 95 deletions

View file

@ -60,6 +60,41 @@ struct ircd_gpt_ctrl_tokens
ulong witnessed;
};
/// Target label register (abridged)
///
struct ircd_gpt_ctrl_logit
{
/// Vocabulary token.
ushort token;
/// Padding #0.
ushort _pad0;
/// Result logit softmax probability.
float samax;
};
/// Target label register (full)
///
struct ircd_gpt_ctrl_label
{
/// Vocabulary token.
ushort token;
/// Padding #0.
ushort _pad0;
/// Result logit softmax probability.
float samax;
/// Loss state
struct ircd_math_mean loss;
/// Perplexity state
struct ircd_math_mean perp;
}
__attribute__((aligned(64)));
/// Task Control Page
///
/// The control block is shared with our device software. Execution state is
@ -76,24 +111,22 @@ struct ircd_gpt_ctrl
/// buffer; the buffer with the tokens themselves is elsewhere.
struct ircd_gpt_ctrl_tokens tokens;
/// Logit softmax state
/// Top result summary from the softed result logit softmax vector. This
/// is updated each cycle by device software with extended statistics on
/// the top N results.
struct ircd_gpt_ctrl_logit top[16];
/// Target label control block. Results for each target are registered
/// and state is updated each cycle.
struct ircd_gpt_ctrl_label label[4];
/// Result logit vector softmax internal state.
struct ircd_math_samax samax;
/// Target label loss state
struct ircd_math_mean loss;
/// Target label perplexity score state
struct ircd_math_mean perp;
/// Target label certainty difference state
struct ircd_math_mean cert;
/// PRNG xoshiro256 state. This is the de facto random seed which can be
/// set before cycle entry by the host. It is updated by device software
/// when used.
/// PRNG xoshiro256 internal state (note: see opts.h to seed the prng).
ulong rand[4];
/// Perform backprop
/// Perform backprop TODO: XXX
bool prop;
/// Header magic 0xC7012C70

View file

@ -45,6 +45,15 @@ struct ircd_gpt_opts
/// Flip a random coin between 0 and top_p ( = 90 = 0.9) for logit select.
uint top_p;
/// Registers the top n result logits in the ctrl block each cycle.
uint top_n;
/// Number of target labels to register results for in the ctrl block.
uint labels;
/// Bitbar toggling various debug modes
uint debug;
/// Specifies the token context size in tokens.
uint context_tokens;
@ -99,12 +108,6 @@ struct ircd_gpt_opts
/// Testing steps
uint testing_steps;
/// Target label
ushort label;
/// Bitbar toggling various debug modes
ushort debug;
/// Learning rate
float alpha;

View file

@ -135,7 +135,7 @@ ircd::gpt::generate(task &task)
cycles
};
backprop(task, ctrl.loss.mean, *model::default_model, momentum);
backprop(task, ctrl.label[0].loss.mean, *model::default_model, momentum);
}
if(ctrl.prop)
@ -143,17 +143,15 @@ ircd::gpt::generate(task &task)
log::debug
{
log, "Backpropagation of %2.6f in %lu cycles.",
ctrl.loss.mean,
ctrl.label[0].loss.mean,
cycles,
};
ctrl.epic.epoch = 0;
ctrl.loss.mean = 0;
ctrl.loss.last = ctrl.loss.mean;
ctrl.perp.mean = 0;
ctrl.perp.last = ctrl.perp.mean;
ctrl.cert.mean = 0;
ctrl.cert.last = ctrl.cert.mean;
ctrl.label[0].loss.mean = 0;
ctrl.label[0].loss.last = ctrl.label[0].loss.mean;
ctrl.label[0].perp.mean = 0;
ctrl.label[0].perp.last = ctrl.label[0].perp.mean;
ctrl.prop = false;
pipe::default_model->invalid = true;
return;
@ -208,16 +206,16 @@ ircd::gpt::generate_debug(task &task,
ctrl.tokens.count,
ctrl.epic.epoch,
ctrl.epic.cycle,
std::clamp(ctrl.cert.mean * 100.0f, 0.0f, 100.0f),
std::clamp(ctrl.perp.mean, 0.0f, 100.0f),
std::clamp(ctrl.loss.mean, 0.0f, 99.99f),
opts.label == tok? '+': ' ',
0.0f, // cert
std::clamp(ctrl.label[0].perp.mean, 0.0f, 100.0f),
std::clamp(ctrl.label[0].loss.mean, 0.0f, 99.99f),
ctrl.label[0].token == tok? '+': ' ',
' ', // flag place
' ', // flag place
opts.label,
std::clamp(ctrl.loss.last, 0.0f, 99.99f),
std::clamp(ctrl.perp.last, 0.0f, 100.0f),
std::clamp(ctrl.cert.last * 100.0f, 0.0f, 100.0f),
ctrl.label[0].token,
std::clamp(ctrl.label[0].loss.last, 0.0f, 99.99f),
std::clamp(ctrl.label[0].perp.last, 0.0f, 100.0f),
0.0f, // cert
vocab::debug(dbuf, tok).c_str(),
tok,
pretty(tmbuf[0], milliseconds(0ms / bsz), 1).c_str(),
@ -287,6 +285,18 @@ noexcept
{
90U
}
,top_n
{
16
}
,labels
{
0
}
,debug
{
0x01
}
,context_tokens
{
1024U
@ -359,14 +369,6 @@ noexcept
{
5000
}
,label
{
198
}
,debug
{
0x01
}
,alpha
{
0.001f

View file

@ -595,38 +595,75 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_ctrl *const ctrl,
inline void
__attribute__((always_inline))
ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
const uint li)
ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit,
const uint i)
{
// No action for other threads right now
if(li != 0)
return;
const ushort
token = idx[i];
if(ctrl->epic.cycle + 1 >= opts->limit)
ctrl->epic.epoch += 1;
const float
samax = logsm[token];
ctrl->epic.cycle += 1;
ctrl->magic = 0xC7012C70U;
ctrl->top[i].token = token;
ctrl->top[i].samax = samax;
}
inline void
__attribute__((always_inline))
ircd_gpt_lm_result(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
const uint li,
__local const ushort *const restrict idx,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit)
ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit,
const uint i)
{
// To read from cells other than idx[0] we need this barrier.
barrier(CLK_LOCAL_MEM_FENCE);
__global struct ircd_gpt_ctrl_label
*const label = ctrl->label + i;
// Mask for write-leader
if(li != 0)
return;
const ushort
token = label->token,
sum_sel = ctrl->epic.cycle % 3;
const float
samax = logsm[token],
mean_div = ctrl->epic.cycle + 1.0f;
const float
loss = 0.0f - log(samax),
loss_sum = label->loss.sum[0] + label->loss.sum[1] + label->loss.sum[2] + loss,
loss_mean = loss_sum / mean_div;
const float
perp = (1.0f - samax) * native_log2(opts->logits),
perp_sum = label->perp.sum[0] + label->perp.sum[1] + label->perp.sum[2] + perp,
perp_mean = perp_sum / mean_div;
label->samax = samax;
label->loss.last = loss;
label->loss.sum[sum_sel] += loss;
label->loss.mean = loss_mean;
label->perp.last = perp;
label->perp.sum[sum_sel] += perp;
label->perp.mean = perp_mean;
}
inline void
__attribute__((always_inline))
ircd_gpt_lm_result_select(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit)
{
const bool
buffer_full = ctrl->tokens.count >= opts->buffer_tokens;
@ -654,35 +691,19 @@ ircd_gpt_lm_result(__global struct ircd_gpt_ctrl *const ctrl,
ctrl->tokens.head = head;
ctrl->tokens.count = tokens;
ctrl->token[dest] = token;
}
const ushort
ln = get_local_size(0),
next_select = (select + 1) % ln,
next_token = idx[next_select],
sum_sel = ctrl->epic.epoch % 3;
inline void
__attribute__((always_inline))
ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
const uint li)
{
if(ctrl->epic.cycle + 1 >= opts->limit)
ctrl->epic.epoch += 1;
const float
test_lsm = logexp[opts->label],
loss = 0.0f - log(test_lsm * ctrl->samax.lambda),
perp = (1.0f - logsm[token]) * native_log2(opts->logits),
cert = (logsm[token] - logsm[next_token]) / logsm[token],
loss_sum = ctrl->loss.sum[0] + ctrl->loss.sum[1] + ctrl->loss.sum[2] + loss,
perp_sum = ctrl->perp.sum[0] + ctrl->perp.sum[1] + ctrl->perp.sum[2] + perp,
cert_sum = ctrl->cert.sum[0] + ctrl->cert.sum[1] + ctrl->cert.sum[2] + cert,
mean_div = ctrl->epic.epoch + 1.0f,
loss_mean = loss_sum / mean_div,
perp_mean = perp_sum / mean_div,
cert_mean = cert_sum / mean_div;
ctrl->loss.last = loss;
ctrl->loss.sum[sum_sel] += loss;
ctrl->loss.mean = loss_mean;
ctrl->perp.last = perp;
ctrl->perp.sum[sum_sel] += perp;
ctrl->perp.mean = perp_mean;
ctrl->cert.last = cert;
ctrl->cert.sum[sum_sel] += cert;
ctrl->cert.mean = cert_mean;
ctrl->epic.cycle += 1;
ctrl->magic = 0xC7012C70U;
}
__kernel void
@ -707,7 +728,23 @@ ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
idx[li] = j;
ircd_simt_sort_idx16_flldr(idx, logsm);
ircd_gpt_lm_result(ctrl, opts, li, idx, logsm, logexp, logit);
if(li < opts->top_n)
ircd_gpt_lm_result_top(ctrl, opts, idx, logsm, logexp, logit, li);
if(li < opts->labels)
ircd_gpt_lm_result_label(ctrl, opts, idx, logsm, logexp, logit, li);
// Writes to `idx` from the sort are still pending across threads.
barrier(CLK_LOCAL_MEM_FENCE);
// Mask for write-leader
if(li == 0)
ircd_gpt_lm_result_select(ctrl, opts, idx, logsm, logexp, logit);
if(li != 0)
return;
ircd_gpt_leave(ctrl, opts, li);
}
@ -729,7 +766,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
const float4
param = param_[li],
grad = ctrl->loss.mean,
grad = ctrl->label[0].loss.mean,
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
exp_avg = step? exp_avg_[li]: 0.0f,
exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,