mirror of
https://github.com/matrix-construct/construct
synced 2024-12-26 15:33:54 +01:00
ircd::gpt: Add top N and target label result register control block.
This commit is contained in:
parent
8bd78af128
commit
aea6c79fc2
4 changed files with 170 additions and 95 deletions
|
@ -60,6 +60,41 @@ struct ircd_gpt_ctrl_tokens
|
|||
ulong witnessed;
|
||||
};
|
||||
|
||||
/// Target label register (abridged)
|
||||
///
|
||||
struct ircd_gpt_ctrl_logit
|
||||
{
|
||||
/// Vocabulary token.
|
||||
ushort token;
|
||||
|
||||
/// Padding #0.
|
||||
ushort _pad0;
|
||||
|
||||
/// Result logit softmax probability.
|
||||
float samax;
|
||||
};
|
||||
|
||||
/// Target label register (full)
|
||||
///
|
||||
struct ircd_gpt_ctrl_label
|
||||
{
|
||||
/// Vocabulary token.
|
||||
ushort token;
|
||||
|
||||
/// Padding #0.
|
||||
ushort _pad0;
|
||||
|
||||
/// Result logit softmax probability.
|
||||
float samax;
|
||||
|
||||
/// Loss state
|
||||
struct ircd_math_mean loss;
|
||||
|
||||
/// Perplexity state
|
||||
struct ircd_math_mean perp;
|
||||
}
|
||||
__attribute__((aligned(64)));
|
||||
|
||||
/// Task Control Page
|
||||
///
|
||||
/// The control block is shared with our device software. Execution state is
|
||||
|
@ -76,24 +111,22 @@ struct ircd_gpt_ctrl
|
|||
/// buffer; the buffer with the tokens themselves is elsewhere.
|
||||
struct ircd_gpt_ctrl_tokens tokens;
|
||||
|
||||
/// Logit softmax state
|
||||
/// Top result summary from the softed result logit softmax vector. This
|
||||
/// is updated each cycle by device software with extended statistics on
|
||||
/// the top N results.
|
||||
struct ircd_gpt_ctrl_logit top[16];
|
||||
|
||||
/// Target label control block. Results for each target are registered
|
||||
/// and state is updated each cycle.
|
||||
struct ircd_gpt_ctrl_label label[4];
|
||||
|
||||
/// Result logit vector softmax internal state.
|
||||
struct ircd_math_samax samax;
|
||||
|
||||
/// Target label loss state
|
||||
struct ircd_math_mean loss;
|
||||
|
||||
/// Target label perplexity score state
|
||||
struct ircd_math_mean perp;
|
||||
|
||||
/// Target label certainty difference state
|
||||
struct ircd_math_mean cert;
|
||||
|
||||
/// PRNG xoshiro256 state. This is the de facto random seed which can be
|
||||
/// set before cycle entry by the host. It is updated by device software
|
||||
/// when used.
|
||||
/// PRNG xoshiro256 internal state (note: see opts.h to seed the prng).
|
||||
ulong rand[4];
|
||||
|
||||
/// Perform backprop
|
||||
/// Perform backprop TODO: XXX
|
||||
bool prop;
|
||||
|
||||
/// Header magic 0xC7012C70
|
||||
|
|
|
@ -45,6 +45,15 @@ struct ircd_gpt_opts
|
|||
/// Flip a random coin between 0 and top_p ( = 90 = 0.9) for logit select.
|
||||
uint top_p;
|
||||
|
||||
/// Registers the top n result logits in the ctrl block each cycle.
|
||||
uint top_n;
|
||||
|
||||
/// Number of target labels to register results for in the ctrl block.
|
||||
uint labels;
|
||||
|
||||
/// Bitbar toggling various debug modes
|
||||
uint debug;
|
||||
|
||||
/// Specifies the token context size in tokens.
|
||||
uint context_tokens;
|
||||
|
||||
|
@ -99,12 +108,6 @@ struct ircd_gpt_opts
|
|||
/// Testing steps
|
||||
uint testing_steps;
|
||||
|
||||
/// Target label
|
||||
ushort label;
|
||||
|
||||
/// Bitbar toggling various debug modes
|
||||
ushort debug;
|
||||
|
||||
/// Learning rate
|
||||
float alpha;
|
||||
|
||||
|
|
50
ircd/gpt.cc
50
ircd/gpt.cc
|
@ -135,7 +135,7 @@ ircd::gpt::generate(task &task)
|
|||
cycles
|
||||
};
|
||||
|
||||
backprop(task, ctrl.loss.mean, *model::default_model, momentum);
|
||||
backprop(task, ctrl.label[0].loss.mean, *model::default_model, momentum);
|
||||
}
|
||||
|
||||
if(ctrl.prop)
|
||||
|
@ -143,17 +143,15 @@ ircd::gpt::generate(task &task)
|
|||
log::debug
|
||||
{
|
||||
log, "Backpropagation of %2.6f in %lu cycles.",
|
||||
ctrl.loss.mean,
|
||||
ctrl.label[0].loss.mean,
|
||||
cycles,
|
||||
};
|
||||
|
||||
ctrl.epic.epoch = 0;
|
||||
ctrl.loss.mean = 0;
|
||||
ctrl.loss.last = ctrl.loss.mean;
|
||||
ctrl.perp.mean = 0;
|
||||
ctrl.perp.last = ctrl.perp.mean;
|
||||
ctrl.cert.mean = 0;
|
||||
ctrl.cert.last = ctrl.cert.mean;
|
||||
ctrl.label[0].loss.mean = 0;
|
||||
ctrl.label[0].loss.last = ctrl.label[0].loss.mean;
|
||||
ctrl.label[0].perp.mean = 0;
|
||||
ctrl.label[0].perp.last = ctrl.label[0].perp.mean;
|
||||
ctrl.prop = false;
|
||||
pipe::default_model->invalid = true;
|
||||
return;
|
||||
|
@ -208,16 +206,16 @@ ircd::gpt::generate_debug(task &task,
|
|||
ctrl.tokens.count,
|
||||
ctrl.epic.epoch,
|
||||
ctrl.epic.cycle,
|
||||
std::clamp(ctrl.cert.mean * 100.0f, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.perp.mean, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.loss.mean, 0.0f, 99.99f),
|
||||
opts.label == tok? '+': ' ',
|
||||
0.0f, // cert
|
||||
std::clamp(ctrl.label[0].perp.mean, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.label[0].loss.mean, 0.0f, 99.99f),
|
||||
ctrl.label[0].token == tok? '+': ' ',
|
||||
' ', // flag place
|
||||
' ', // flag place
|
||||
opts.label,
|
||||
std::clamp(ctrl.loss.last, 0.0f, 99.99f),
|
||||
std::clamp(ctrl.perp.last, 0.0f, 100.0f),
|
||||
std::clamp(ctrl.cert.last * 100.0f, 0.0f, 100.0f),
|
||||
ctrl.label[0].token,
|
||||
std::clamp(ctrl.label[0].loss.last, 0.0f, 99.99f),
|
||||
std::clamp(ctrl.label[0].perp.last, 0.0f, 100.0f),
|
||||
0.0f, // cert
|
||||
vocab::debug(dbuf, tok).c_str(),
|
||||
tok,
|
||||
pretty(tmbuf[0], milliseconds(0ms / bsz), 1).c_str(),
|
||||
|
@ -287,6 +285,18 @@ noexcept
|
|||
{
|
||||
90U
|
||||
}
|
||||
,top_n
|
||||
{
|
||||
16
|
||||
}
|
||||
,labels
|
||||
{
|
||||
0
|
||||
}
|
||||
,debug
|
||||
{
|
||||
0x01
|
||||
}
|
||||
,context_tokens
|
||||
{
|
||||
1024U
|
||||
|
@ -359,14 +369,6 @@ noexcept
|
|||
{
|
||||
5000
|
||||
}
|
||||
,label
|
||||
{
|
||||
198
|
||||
}
|
||||
,debug
|
||||
{
|
||||
0x01
|
||||
}
|
||||
,alpha
|
||||
{
|
||||
0.001f
|
||||
|
|
139
ircd/gpt_gpu.cl
139
ircd/gpt_gpu.cl
|
@ -595,38 +595,75 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_ctrl *const ctrl,
|
|||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
const uint li)
|
||||
ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local const ushort *const restrict idx,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit,
|
||||
const uint i)
|
||||
{
|
||||
// No action for other threads right now
|
||||
if(li != 0)
|
||||
return;
|
||||
const ushort
|
||||
token = idx[i];
|
||||
|
||||
if(ctrl->epic.cycle + 1 >= opts->limit)
|
||||
ctrl->epic.epoch += 1;
|
||||
const float
|
||||
samax = logsm[token];
|
||||
|
||||
ctrl->epic.cycle += 1;
|
||||
ctrl->magic = 0xC7012C70U;
|
||||
ctrl->top[i].token = token;
|
||||
ctrl->top[i].samax = samax;
|
||||
}
|
||||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_lm_result(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
const uint li,
|
||||
__local const ushort *const restrict idx,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit)
|
||||
ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local const ushort *const restrict idx,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit,
|
||||
const uint i)
|
||||
{
|
||||
// To read from cells other than idx[0] we need this barrier.
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
__global struct ircd_gpt_ctrl_label
|
||||
*const label = ctrl->label + i;
|
||||
|
||||
// Mask for write-leader
|
||||
if(li != 0)
|
||||
return;
|
||||
const ushort
|
||||
token = label->token,
|
||||
sum_sel = ctrl->epic.cycle % 3;
|
||||
|
||||
const float
|
||||
samax = logsm[token],
|
||||
mean_div = ctrl->epic.cycle + 1.0f;
|
||||
|
||||
const float
|
||||
loss = 0.0f - log(samax),
|
||||
loss_sum = label->loss.sum[0] + label->loss.sum[1] + label->loss.sum[2] + loss,
|
||||
loss_mean = loss_sum / mean_div;
|
||||
|
||||
const float
|
||||
perp = (1.0f - samax) * native_log2(opts->logits),
|
||||
perp_sum = label->perp.sum[0] + label->perp.sum[1] + label->perp.sum[2] + perp,
|
||||
perp_mean = perp_sum / mean_div;
|
||||
|
||||
label->samax = samax;
|
||||
|
||||
label->loss.last = loss;
|
||||
label->loss.sum[sum_sel] += loss;
|
||||
label->loss.mean = loss_mean;
|
||||
|
||||
label->perp.last = perp;
|
||||
label->perp.sum[sum_sel] += perp;
|
||||
label->perp.mean = perp_mean;
|
||||
}
|
||||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_lm_result_select(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local const ushort *const restrict idx,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit)
|
||||
{
|
||||
const bool
|
||||
buffer_full = ctrl->tokens.count >= opts->buffer_tokens;
|
||||
|
||||
|
@ -654,35 +691,19 @@ ircd_gpt_lm_result(__global struct ircd_gpt_ctrl *const ctrl,
|
|||
ctrl->tokens.head = head;
|
||||
ctrl->tokens.count = tokens;
|
||||
ctrl->token[dest] = token;
|
||||
}
|
||||
|
||||
const ushort
|
||||
ln = get_local_size(0),
|
||||
next_select = (select + 1) % ln,
|
||||
next_token = idx[next_select],
|
||||
sum_sel = ctrl->epic.epoch % 3;
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
const uint li)
|
||||
{
|
||||
if(ctrl->epic.cycle + 1 >= opts->limit)
|
||||
ctrl->epic.epoch += 1;
|
||||
|
||||
const float
|
||||
test_lsm = logexp[opts->label],
|
||||
loss = 0.0f - log(test_lsm * ctrl->samax.lambda),
|
||||
perp = (1.0f - logsm[token]) * native_log2(opts->logits),
|
||||
cert = (logsm[token] - logsm[next_token]) / logsm[token],
|
||||
loss_sum = ctrl->loss.sum[0] + ctrl->loss.sum[1] + ctrl->loss.sum[2] + loss,
|
||||
perp_sum = ctrl->perp.sum[0] + ctrl->perp.sum[1] + ctrl->perp.sum[2] + perp,
|
||||
cert_sum = ctrl->cert.sum[0] + ctrl->cert.sum[1] + ctrl->cert.sum[2] + cert,
|
||||
mean_div = ctrl->epic.epoch + 1.0f,
|
||||
loss_mean = loss_sum / mean_div,
|
||||
perp_mean = perp_sum / mean_div,
|
||||
cert_mean = cert_sum / mean_div;
|
||||
|
||||
ctrl->loss.last = loss;
|
||||
ctrl->loss.sum[sum_sel] += loss;
|
||||
ctrl->loss.mean = loss_mean;
|
||||
ctrl->perp.last = perp;
|
||||
ctrl->perp.sum[sum_sel] += perp;
|
||||
ctrl->perp.mean = perp_mean;
|
||||
ctrl->cert.last = cert;
|
||||
ctrl->cert.sum[sum_sel] += cert;
|
||||
ctrl->cert.mean = cert_mean;
|
||||
ctrl->epic.cycle += 1;
|
||||
ctrl->magic = 0xC7012C70U;
|
||||
}
|
||||
|
||||
__kernel void
|
||||
|
@ -707,7 +728,23 @@ ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
|
|||
idx[li] = j;
|
||||
|
||||
ircd_simt_sort_idx16_flldr(idx, logsm);
|
||||
ircd_gpt_lm_result(ctrl, opts, li, idx, logsm, logexp, logit);
|
||||
|
||||
if(li < opts->top_n)
|
||||
ircd_gpt_lm_result_top(ctrl, opts, idx, logsm, logexp, logit, li);
|
||||
|
||||
if(li < opts->labels)
|
||||
ircd_gpt_lm_result_label(ctrl, opts, idx, logsm, logexp, logit, li);
|
||||
|
||||
// Writes to `idx` from the sort are still pending across threads.
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Mask for write-leader
|
||||
if(li == 0)
|
||||
ircd_gpt_lm_result_select(ctrl, opts, idx, logsm, logexp, logit);
|
||||
|
||||
if(li != 0)
|
||||
return;
|
||||
|
||||
ircd_gpt_leave(ctrl, opts, li);
|
||||
}
|
||||
|
||||
|
@ -729,7 +766,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
|
||||
const float4
|
||||
param = param_[li],
|
||||
grad = ctrl->loss.mean,
|
||||
grad = ctrl->label[0].loss.mean,
|
||||
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
|
||||
exp_avg = step? exp_avg_[li]: 0.0f,
|
||||
exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,
|
||||
|
|
Loading…
Reference in a new issue