mirror of
https://github.com/matrix-construct/construct
synced 2024-11-18 07:50:57 +01:00
ircd::gpt: Backpropagate adaptive moment estimations.
This commit is contained in:
parent
da5fa1217a
commit
1f49f71530
1 changed files with 254 additions and 0 deletions
254
ircd/gpt.cc
254
ircd/gpt.cc
|
@ -10,6 +10,16 @@
|
|||
|
||||
namespace ircd::gpt
|
||||
{
|
||||
static size_t adamw(f32x4 &, f32x4 &, f32x4 &, const f32, const f32, const f32, const f32, const u32, size_t);
|
||||
static size_t adamw(task &, const f32, f32 *, const size_t, f32 *const (&)[2], const size_t);
|
||||
|
||||
static size_t backprop(task &, const f32, model::norm &, f32 *const (&)[2], size_t);
|
||||
static size_t backprop(task &, const f32, model::attn &, f32 *const (&)[2], size_t);
|
||||
static size_t backprop(task &, const f32, model::ffnn &, f32 *const (&)[2], size_t);
|
||||
static size_t backprop(task &, const f32, model::block &, f32 *const (&)[2], size_t);
|
||||
static size_t backprop(task &, const f32, model::embed &, f32 *const (&)[2], size_t);
|
||||
static size_t backprop(task &, const f32, model::decoder &, f32 *const (&)[2], size_t = 0);
|
||||
|
||||
template<class T>
|
||||
static void fmma(T *out, const T *in, const T *bias, const T *weight, const math::fmma_opts &);
|
||||
|
||||
|
@ -141,6 +151,50 @@ ircd::gpt::generate(const vector_view<u16> &out,
|
|||
}
|
||||
|
||||
uint64_t cycles(0);
|
||||
if(ctrl.prop)
|
||||
{
|
||||
static f32 *_momentum[2];
|
||||
if(!_momentum[0])
|
||||
{
|
||||
_momentum[0] = new f32[sizeof(model::decoder) / 4] {0.0f};
|
||||
_momentum[1] = new f32[sizeof(model::decoder) / 4] {0.0f};
|
||||
}
|
||||
|
||||
f32 *const momentum[2]
|
||||
{
|
||||
_momentum[0], _momentum[1],
|
||||
};
|
||||
|
||||
const prof::scope_cycles task_cycles
|
||||
{
|
||||
cycles
|
||||
};
|
||||
|
||||
backprop(task, ctrl.loss_mean, *model::default_model, momentum);
|
||||
}
|
||||
|
||||
if(ctrl.prop)
|
||||
{
|
||||
log::debug
|
||||
{
|
||||
log, "Backpropagation of %2.6f in %lu cycles.",
|
||||
ctrl.loss_mean,
|
||||
cycles,
|
||||
};
|
||||
|
||||
ctrl.epoch = 0;
|
||||
ctrl.loss_mean = 0;
|
||||
ctrl.loss = ctrl.loss_mean;
|
||||
ctrl.perp_mean = 0;
|
||||
ctrl.perp = ctrl.perp_mean;
|
||||
ctrl.cert_mean = 0;
|
||||
ctrl.cert = ctrl.cert_mean;
|
||||
ctrl.prop = false;
|
||||
pipe::default_model->invalid = true;
|
||||
return {};
|
||||
}
|
||||
|
||||
cycles = 0;
|
||||
milliseconds last_time {0};
|
||||
util::timer stopwatch;
|
||||
{
|
||||
|
@ -631,6 +685,206 @@ ircd::gpt::gelu(f32x4 &out,
|
|||
out = 0.5 * in * (1.0 + tanh(in * f32(0.7978845608) * (1.0 + f32(0.044715) * in * in)));
|
||||
}
|
||||
|
||||
//
|
||||
// backside
|
||||
//
|
||||
|
||||
size_t
|
||||
ircd::gpt::backprop(task &task,
|
||||
const f32 grad,
|
||||
model::decoder ¶m,
|
||||
f32 *const (&moment)[2],
|
||||
size_t off)
|
||||
{
|
||||
for(uint i(0); i < 12; ++i)
|
||||
off = backprop(task, grad, param.layer[i], moment, off);
|
||||
|
||||
off = backprop(task, grad, param.f, moment, off);
|
||||
off = backprop(task, grad, param.word, moment, off);
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::backprop(task &task,
|
||||
const f32 grad,
|
||||
model::embed ¶m,
|
||||
f32 *const (&moment)[2],
|
||||
size_t off)
|
||||
{
|
||||
assert(task.opts);
|
||||
const auto &opts
|
||||
{
|
||||
*task.opts
|
||||
};
|
||||
|
||||
for(uint i(0); i < opts.context_tokens; ++i)
|
||||
off = adamw(task, grad, param.pos[i], 768, moment, off);
|
||||
|
||||
for(uint i(0); i < opts.logits; ++i)
|
||||
off = adamw(task, grad, param.token[i], 768, moment, off);
|
||||
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::backprop(task &task,
|
||||
const f32 grad,
|
||||
model::block ¶m,
|
||||
f32 *const (&moment)[2],
|
||||
size_t off)
|
||||
{
|
||||
off = backprop(task, grad, param.ln1, moment, off);
|
||||
off = backprop(task, grad, param.attn, moment, off);
|
||||
off = backprop(task, grad, param.ln2, moment, off);
|
||||
off = backprop(task, grad, param.ffnn, moment, off);
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::backprop(task &task,
|
||||
const f32 grad,
|
||||
model::attn ¶m,
|
||||
f32 *const (&moment)[2],
|
||||
size_t off)
|
||||
{
|
||||
off = adamw(task, grad, param.attn_bias, 2304, moment, off);
|
||||
|
||||
for(uint i(0); i < 768; ++i)
|
||||
off = adamw(task, grad, param.attn_weight[i], 2304, moment, off);
|
||||
|
||||
off = adamw(task, grad, param.proj_bias, 768, moment, off);
|
||||
|
||||
for(uint i(0); i < 768; ++i)
|
||||
off = adamw(task, grad, param.proj_weight[i], 768, moment, off);
|
||||
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::backprop(task &task,
|
||||
const f32 grad,
|
||||
model::ffnn ¶m,
|
||||
f32 *const (&moment)[2],
|
||||
size_t off)
|
||||
{
|
||||
off = adamw(task, grad, param.fc_bias, 3072, moment, off);
|
||||
|
||||
for(uint i(0); i < 768; ++i)
|
||||
off = adamw(task, grad, param.fc_weight[i], 3072, moment, off);
|
||||
|
||||
off = adamw(task, grad, param.proj_bias, 768, moment, off);
|
||||
|
||||
for(uint i(0); i < 3072; ++i)
|
||||
off = adamw(task, grad, param.proj_weight[i], 768, moment, off);
|
||||
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::backprop(task &task,
|
||||
const f32 grad,
|
||||
model::norm ¶m,
|
||||
f32 *const (&moment)[2],
|
||||
size_t off)
|
||||
{
|
||||
off = adamw(task, grad, param.bias, 768, moment, off);
|
||||
off = adamw(task, grad, param.weight, 768, moment, off);
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::adamw(task &task,
|
||||
const f32 grad,
|
||||
f32 *const p_,
|
||||
const size_t num,
|
||||
f32 *const (&__restrict__ m_)[2],
|
||||
size_t off)
|
||||
{
|
||||
assert(task.opts);
|
||||
const auto &opts
|
||||
{
|
||||
*task.opts
|
||||
};
|
||||
|
||||
assert(task.ctrl);
|
||||
auto &ctrl
|
||||
{
|
||||
*task.ctrl
|
||||
};
|
||||
|
||||
f32x4 *const p[3]
|
||||
{
|
||||
reinterpret_cast<f32x4 *>(p_),
|
||||
reinterpret_cast<f32x4 *>(m_[0]) + off,
|
||||
reinterpret_cast<f32x4 *>(m_[1]) + off,
|
||||
};
|
||||
|
||||
for(uint i(0); i < num / 4; ++i)
|
||||
off = adamw(p[0][i], p[1][i], p[2][i], grad, opts.alpha, opts.beta[0], opts.beta[1], ctrl.step, off);
|
||||
|
||||
return off;
|
||||
}
|
||||
|
||||
size_t
|
||||
ircd::gpt::adamw(f32x4 &__restrict__ param,
|
||||
f32x4 &__restrict__ moment0,
|
||||
f32x4 &__restrict__ moment1,
|
||||
const f32 grad,
|
||||
const f32 alpha,
|
||||
const f32 beta0,
|
||||
const f32 beta1,
|
||||
const u32 step,
|
||||
const size_t off)
|
||||
{
|
||||
const f32x4 one
|
||||
{
|
||||
1.0f, 1.0f, 1.0f, 1.0f,
|
||||
};
|
||||
|
||||
const f32x4 a[2]
|
||||
{
|
||||
{ one - beta0 },
|
||||
{ one - beta1 },
|
||||
};
|
||||
|
||||
const f32x4 avg_mul[2]
|
||||
{
|
||||
{ moment0 * beta0 },
|
||||
{ moment1 * beta1 },
|
||||
};
|
||||
|
||||
const f32x4 avg_dot[2]
|
||||
{
|
||||
{ avg_mul[0] + a[0] * grad },
|
||||
{ avg_mul[1] + a[1] * grad * grad },
|
||||
};
|
||||
|
||||
const f32x4 bias[2]
|
||||
{
|
||||
{ avg_dot[0] / (one - powf(beta0, step + 1)) },
|
||||
{ avg_dot[1] / (one - powf(beta1, step + 1)) },
|
||||
};
|
||||
|
||||
const f32x4 denom
|
||||
{
|
||||
sqrtf(bias[1]) + 0.000001f // epsilon
|
||||
};
|
||||
|
||||
const f32x4 delta
|
||||
{
|
||||
alpha * (bias[0] / denom)
|
||||
};
|
||||
|
||||
const f32x4 update
|
||||
{
|
||||
param - delta
|
||||
};
|
||||
|
||||
moment0 = avg_dot[0];
|
||||
moment1 = avg_dot[1];
|
||||
param = update;
|
||||
return off + 1;
|
||||
}
|
||||
|
||||
//
|
||||
// gpt::task
|
||||
|
|
Loading…
Reference in a new issue