From 1f49f71530d76e576a696859d243f5e55d77b4df Mon Sep 17 00:00:00 2001
From: Jason Volk
Date: Thu, 22 Apr 2021 12:20:58 -0700
Subject: [PATCH] ircd::gpt: Backpropagate adaptive moment estimations.

---
 ircd/gpt.cc | 254 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 254 insertions(+)

diff --git a/ircd/gpt.cc b/ircd/gpt.cc
index 19916df78..426231663 100644
--- a/ircd/gpt.cc
+++ b/ircd/gpt.cc
@@ -10,6 +10,16 @@ namespace ircd::gpt
 {
+	static size_t adamw(f32x4 &, f32x4 &, f32x4 &, const f32, const f32, const f32, const f32, const u32, size_t);
+	static size_t adamw(task &, const f32, f32 *, const size_t, f32 *const (&)[2], const size_t);
+
+	static size_t backprop(task &, const f32, model::norm &, f32 *const (&)[2], size_t);
+	static size_t backprop(task &, const f32, model::attn &, f32 *const (&)[2], size_t);
+	static size_t backprop(task &, const f32, model::ffnn &, f32 *const (&)[2], size_t);
+	static size_t backprop(task &, const f32, model::block &, f32 *const (&)[2], size_t);
+	static size_t backprop(task &, const f32, model::embed &, f32 *const (&)[2], size_t);
+	static size_t backprop(task &, const f32, model::decoder &, f32 *const (&)[2], size_t = 0);
+
 	template<class T>
 	static void
 	fmma(T *out, const T *in, const T *bias, const T *weight, const math::fmma_opts &);
@@ -141,6 +151,50 @@ ircd::gpt::generate(const vector_view &out,
 	}
 
 	uint64_t cycles(0);
+	if(ctrl.prop)
+	{
+		static f32 *_momentum[2];
+		if(!_momentum[0])
+		{
+			_momentum[0] = new f32[sizeof(model::decoder) / 4] {0.0f};
+			_momentum[1] = new f32[sizeof(model::decoder) / 4] {0.0f};
+		}
+
+		f32 *const momentum[2]
+		{
+			_momentum[0], _momentum[1],
+		};
+
+		const prof::scope_cycles task_cycles
+		{
+			cycles
+		};
+
+		backprop(task, ctrl.loss_mean, *model::default_model, momentum);
+	}
+
+	if(ctrl.prop)
+	{
+		log::debug
+		{
+			log, "Backpropagation of %2.6f in %lu cycles.",
+			ctrl.loss_mean,
+			cycles,
+		};
+
+		ctrl.epoch = 0;
+		ctrl.loss_mean = 0;
+		ctrl.loss = ctrl.loss_mean;
+		ctrl.perp_mean = 0;
+		ctrl.perp = ctrl.perp_mean;
+		ctrl.cert_mean = 0;
+		ctrl.cert = ctrl.cert_mean;
+		ctrl.prop = false;
+		pipe::default_model->invalid = true;
+		return {};
+	}
+
+	cycles = 0;
 	milliseconds last_time {0};
 	util::timer stopwatch;
 	{
@@ -631,6 +685,206 @@ ircd::gpt::gelu(f32x4 &out,
 	out = 0.5 * in * (1.0 + tanh(in * f32(0.7978845608) * (1.0 + f32(0.044715) * in * in)));
 }
 
+//
+// backside
+//
+
+size_t
+ircd::gpt::backprop(task &task,
+                    const f32 grad,
+                    model::decoder &param,
+                    f32 *const (&moment)[2],
+                    size_t off)
+{
+	for(uint i(0); i < 12; ++i)
+		off = backprop(task, grad, param.layer[i], moment, off);
+
+	off = backprop(task, grad, param.f, moment, off);
+	off = backprop(task, grad, param.word, moment, off);
+	return off;
+}
+
+size_t
+ircd::gpt::backprop(task &task,
+                    const f32 grad,
+                    model::embed &param,
+                    f32 *const (&moment)[2],
+                    size_t off)
+{
+	assert(task.opts);
+	const auto &opts
+	{
+		*task.opts
+	};
+
+	for(uint i(0); i < opts.context_tokens; ++i)
+		off = adamw(task, grad, param.pos[i], 768, moment, off);
+
+	for(uint i(0); i < opts.logits; ++i)
+		off = adamw(task, grad, param.token[i], 768, moment, off);
+
+	return off;
+}
+
+size_t
+ircd::gpt::backprop(task &task,
+                    const f32 grad,
+                    model::block &param,
+                    f32 *const (&moment)[2],
+                    size_t off)
+{
+	off = backprop(task, grad, param.ln1, moment, off);
+	off = backprop(task, grad, param.attn, moment, off);
+	off = backprop(task, grad, param.ln2, moment, off);
+	off = backprop(task, grad, param.ffnn, moment, off);
+	return off;
+}
+
+size_t
+ircd::gpt::backprop(task &task,
+                    const f32 grad,
+                    model::attn &param,
+                    f32 *const (&moment)[2],
+                    size_t off)
+{
+	off = adamw(task, grad, param.attn_bias, 2304, moment, off);
+
+	for(uint i(0); i < 768; ++i)
+		off = adamw(task, grad, param.attn_weight[i], 2304, moment, off);
+
+	off = adamw(task, grad, param.proj_bias, 768, moment, off);
+
+	for(uint i(0); i < 768; ++i)
+		off = adamw(task, grad, param.proj_weight[i], 768, moment, off);
+
+	return off;
+}
+
+size_t
+ircd::gpt::backprop(task &task,
+                    const f32 grad,
+                    model::ffnn &param,
+                    f32 *const (&moment)[2],
+                    size_t off)
+{
+	off = adamw(task, grad, param.fc_bias, 3072, moment, off);
+
+	for(uint i(0); i < 768; ++i)
+		off = adamw(task, grad, param.fc_weight[i], 3072, moment, off);
+
+	off = adamw(task, grad, param.proj_bias, 768, moment, off);
+
+	for(uint i(0); i < 3072; ++i)
+		off = adamw(task, grad, param.proj_weight[i], 768, moment, off);
+
+	return off;
+}
+
+size_t
+ircd::gpt::backprop(task &task,
+                    const f32 grad,
+                    model::norm &param,
+                    f32 *const (&moment)[2],
+                    size_t off)
+{
+	off = adamw(task, grad, param.bias, 768, moment, off);
+	off = adamw(task, grad, param.weight, 768, moment, off);
+	return off;
+}
+
+size_t
+ircd::gpt::adamw(task &task,
+                 const f32 grad,
+                 f32 *const p_,
+                 const size_t num,
+                 f32 *const (&__restrict__ m_)[2],
+                 size_t off)
+{
+	assert(task.opts);
+	const auto &opts
+	{
+		*task.opts
+	};
+
+	assert(task.ctrl);
+	auto &ctrl
+	{
+		*task.ctrl
+	};
+
+	f32x4 *const p[3]
+	{
+		reinterpret_cast<f32x4 *>(p_),
+		reinterpret_cast<f32x4 *>(m_[0]) + off,
+		reinterpret_cast<f32x4 *>(m_[1]) + off,
+	};
+
+	for(uint i(0); i < num / 4; ++i)
+		off = adamw(p[0][i], p[1][i], p[2][i], grad, opts.alpha, opts.beta[0], opts.beta[1], ctrl.step, off);
+
+	return off;
+}
+
+size_t
+ircd::gpt::adamw(f32x4 &__restrict__ param,
+                 f32x4 &__restrict__ moment0,
+                 f32x4 &__restrict__ moment1,
+                 const f32 grad,
+                 const f32 alpha,
+                 const f32 beta0,
+                 const f32 beta1,
+                 const u32 step,
+                 const size_t off)
+{
+	const f32x4 one
+	{
+		1.0f, 1.0f, 1.0f, 1.0f,
+	};
+
+	const f32x4 a[2]
+	{
+		{ one - beta0 },
+		{ one - beta1 },
+	};
+
+	const f32x4 avg_mul[2]
+	{
+		{ moment0 * beta0 },
+		{ moment1 * beta1 },
+	};
+
+	const f32x4 avg_dot[2]
+	{
+		{ avg_mul[0] + a[0] * grad },
+		{ avg_mul[1] + a[1] * grad * grad },
+	};
+
+	const f32x4 bias[2]
+	{
+		{ avg_dot[0] / (one - powf(beta0, step + 1)) },
+		{ avg_dot[1] / (one - powf(beta1, step + 1)) },
+	};
+
+	const f32x4 denom
+	{
+		sqrtf(bias[1]) + 0.000001f // epsilon
+	};
+
+	const f32x4 delta
+	{
+		alpha * (bias[0] / denom)
+	};
+
+	const f32x4 update
+	{
+		param - delta
+	};
+
+	moment0 = avg_dot[0];
+	moment1 = avg_dot[1];
+	param = update;
+	return off + 1;
+}
 
 //
 // gpt::task
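
Note: the vectorized adamw() above is the standard bias-corrected Adam update, with the
single scalar gradient (ctrl.loss_mean) broadcast across every f32x4 lane of each parameter
block, and the two momentum arrays laid out parallel to the decoder so the off counter
indexes both. For reference, a minimal scalar sketch of the same step follows; the helper
name adamw_step, the adamw_state struct, and the default epsilon are illustrative only and
not part of this patch.

	#include <cmath>
	#include <cstdint>

	struct adamw_state
	{
		float moment0 {0.0f};   // running mean of the gradient
		float moment1 {0.0f};   // running mean of the squared gradient
	};

	// One update of a single parameter; step is zero-based like ctrl.step.
	inline float
	adamw_step(float param,
	           adamw_state &s,
	           const float grad,
	           const float alpha,     // learning rate (opts.alpha)
	           const float beta0,     // first-moment decay (opts.beta[0])
	           const float beta1,     // second-moment decay (opts.beta[1])
	           const uint32_t step,
	           const float eps = 1e-6f)
	{
		// Exponential moving averages of the gradient and its square.
		s.moment0 = beta0 * s.moment0 + (1.0f - beta0) * grad;
		s.moment1 = beta1 * s.moment1 + (1.0f - beta1) * grad * grad;

		// Correct the bias introduced by the zero-initialized moments.
		const float hat0 = s.moment0 / (1.0f - std::pow(beta0, step + 1));
		const float hat1 = s.moment1 / (1.0f - std::pow(beta1, step + 1));

		// Descend along the bias-corrected first moment, scaled by the second;
		// a decoupled weight-decay term (alpha * lambda * param) would also be
		// subtracted here, which this patch does not add.
		return param - alpha * hat0 / (std::sqrt(hat1) + eps);
	}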