2021-03-05 02:03:33 +01:00
|
|
|
// Matrix Construct Is All You Need Is All You Need Is AllĊĊĊĊĊĊĊĊ
|
|
|
|
//
|
|
|
|
// Copyright (C) Matrix Construct Developers, Authors & Contributors
|
|
|
|
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
|
|
|
|
//
|
|
|
|
// Permission to use, copy, modify, and/or distribute this software for any
|
|
|
|
// purpose with or without fee is hereby granted, provided that the above
|
|
|
|
// copyright notice and this permission notice is present in all copies. The
|
|
|
|
// full license for this software is available in the LICENSE file.
|
|
|
|
|
|
|
|
namespace ircd::gpt
|
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
template<class T>
|
|
|
|
static void fmma(T *out, const T *in, const T *bias, const T *weight, const math::fmma_opts &);
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
static void gelu(f32x4 &, const f32x4 &);
|
|
|
|
static void gelu(f32x4 *, const f32x4 *);
|
|
|
|
static void norm(f32x4 *, const f32x4 *, const f32x4 *, const f32x4 *, const f32);
|
2021-03-05 02:03:33 +01:00
|
|
|
static void vals(float (&)[12][1024][64], const float (&)[12][1024][1024], const float (&)[3][1024][12][64], const size_t);
|
|
|
|
static void pare(float (&)[12][1024][1024], const float (&)[3][1024][12][64], const size_t);
|
|
|
|
static void mask(float (&)[12][1024][1024], const float (&)[12][1024][1024], const bool (&)[1024][1024], const size_t);
|
|
|
|
static void smax(float (&)[12][1024][1024], const float (&)[12][1024][1024], const size_t);
|
2021-03-30 03:18:59 +02:00
|
|
|
static void ctrl(float (&)[3][1024][12][64], const float *const, const size_t, const model::decoder &, const uint layer);
|
|
|
|
static void ffnn(float *, const float *, const model::decoder &, const uint layer);
|
|
|
|
static void coil(float *, const size_t, const model::decoder &);
|
|
|
|
static void logitsmax(float *, const float *, const size_t);
|
|
|
|
static void logits(float *, const float *, const model::decoder &);
|
2021-03-09 11:08:47 +01:00
|
|
|
static void tail(float *, const float *, const model::decoder &);
|
|
|
|
static u16 argmax(const float *, const opts &);
|
2021-03-10 09:18:23 +01:00
|
|
|
static void embed(float *, const u16 token, const u16 position, const opts &);
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
static f32
|
|
|
|
logit alignas(64) [65536],
|
2021-04-02 22:01:38 +02:00
|
|
|
embeds alignas(64) [1024 * 768],
|
2021-03-05 02:03:33 +01:00
|
|
|
scratch alignas(64) [1024 * 768];
|
|
|
|
}
|
|
|
|
|
2021-03-09 11:08:47 +01:00
|
|
|
decltype(ircd::gpt::log)
|
|
|
|
ircd::gpt::log
|
|
|
|
{
|
|
|
|
"gpt"
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Text-level entry point: tokenize `in`, run token generation with `task`,
/// and detokenize the generated tokens into `out`. Returns the view of `out`
/// produced by vocab::detokenize.
ircd::string_view
ircd::gpt::generate(const mutable_buffer &out,
                    const string_view &in,
                    task &task)
{
	// Two token scratch areas: [0] receives the tokenized prompt, [1] the
	// generated continuation.
	u16 buf[2][1024];

	const auto prompt(vocab::tokenize(buf[0], in));
	const auto continuation(generate(buf[1], prompt, task));
	return vocab::detokenize(out, continuation);
}
|
|
|
|
|
|
|
|
/// Run inference for `task`, seeded with the prompt tokens in `in`, and copy
/// the generated continuation into `out`.
///
/// If ircd::cl::enable is set, the prompt is staged into ctrl.token.
/// Otherwise each prompt token is embedded on the host into gpt::embeds.
/// After generate(task) returns, generated tokens are copied into `out`.
/// Copying stops when `out` is full, when every generated token has been
/// copied, or when one of the accept/error code sequences configured in opts
/// completes.
///
/// Returns the prefix of `out` that was actually written.
ircd::vector_view<ircd::u16>
ircd::gpt::generate(const vector_view<u16> &out,
                    const vector_view<const u16> &in,
                    task &task)
{
	assert(task.ctrl);
	assert(task.opts);

	uint ret(0);
	bool halt(false);

	const auto &opts(*task.opts);
	auto &ctrl(*task.ctrl);
	auto &errc(ctrl.error_seq);
	auto &accc(ctrl.accept_seq);
	ctrl.tokens = in.size();
	ctrl.head = 0;

	const size_t tmax
	{
		in.size() + opts.limit
	};

	const vector_view<f32> embeds
	{
		gpt::embeds, tmax * 768
	};

	for(uint j(0); j < in.size(); ++j)
	{
		const vector_view<f32> dst
		{
			data(embeds) + j * 768, 768
		};

		if(ircd::cl::enable)
			ctrl.token[j] = in[j];
		else
			embed(data(dst), in[j], j, opts);

		#if 0 // RB_DEBUG
		static char dbuf[512] {0};
		char report[1536] {0};
		char tmbuf[1][64] {{0}};
		const size_t report_size = snprintf
		(
			report, sizeof(report),
			"%-4u %4u %4u:%-4u %1u%1u [ %6.2fL %6.2f%% ] %6.2fL %5.1f%% %s",
			ctrl.epoch,
			ctrl.cycle,
			j,
			ctrl.tokens,
			0,
			0,
			0.0,
			0.0,
			0.0,
			0.0,
			vocab::debug(dbuf, in[j]).c_str()
		);

		log::logf
		{
			log, log::level::DEBUG,
			"%s",
			string_view{report, report_size}
		};
		#endif
	}

	uint64_t cycles(0);
	milliseconds last_time {0};
	util::timer stopwatch;
	{
		const prof::scope_cycles task_cycles
		{
			cycles
		};

		generate(task);
	}
	last_time = stopwatch.at<milliseconds>();
	ctrl.elapsed += last_time.count();

	/*
	coil(data(scratch), tokens, *opts.model);
	tail(logit, data(last_embed), *opts.model);
	out[i] = argmax(logit, *opts);
	*/

	// Effective length of each configured code sequence. Trailing slots
	// holding -1U are unused. A threshold of zero disables that sequence.
	uint accc_thresh[3] {3, 3, 3};
	for(uint i(0); i < 3; ++i)
		for(uint j(3); j > 0; --j)
			if(opts.accept_code[i][j - 1] == -1U)
				--accc_thresh[i];
			else
				break;

	uint errc_thresh[3] {3, 3, 3};
	for(uint i(0); i < 3; ++i)
		for(uint j(3); j > 0; --j)
			if(opts.error_code[i][j - 1] == -1U)
				--errc_thresh[i];
			else
				break;

	for(auto &j(ret); j + in.size() < ctrl.tokens && j < out.size() && !halt; ++j)
	{
		const u16 token(ctrl.token[(in.size() + j + ctrl.head) % opts.buffer_tokens]);
		out[j] = token;

		// Advance each sequence matcher against the token just produced.
		// - The matcher index is `i` so it does not shadow the output position
		//   `j`. Shadowing made the matchers compare out[0..2] instead of the
		//   current token.
		// - The match count is bounded by the three codes per row. This
		//   assumes 3 entries per row, as the threshold scan above does. A
		//   count that already reached 3 (e.g. carried over in ctrl from an
		//   earlier call) therefore never indexes past the row.
		for(uint i(0); i < 3; ++i)
			errc[i] = errc[i] < 3 && opts.error_code[i][errc[i]] == token?
				errc[i] + 1: 0;

		for(uint i(0); i < 3; ++i)
			accc[i] = accc[i] < 3 && opts.accept_code[i][accc[i]] == token?
				accc[i] + 1: 0;

		for(uint i(0); i < 3; ++i)
		{
			halt |= accc_thresh[i] && accc[i] >= accc_thresh[i];
			halt |= errc_thresh[i] && errc[i] >= errc_thresh[i];
		}

		static char dbuf[512] {0};
		char report[1536] {0};
		char tmbuf[4][64] {0};
		const size_t bsz(ctrl.tokens - in.size());
		const size_t report_size = snprintf
		(
			report, sizeof(report),
			"%4lu:%-4u %4lu:%-4lu %6.1f%% %5.1fP %6.3fL [%c%c%c] %5u %6.3fL %6.2fP %5.1f%% %s %04x %8s %8s | %8s",
			j + in.size(),
			ctrl.tokens,
			ctrl.epoch,
			ctrl.cycle,
			std::clamp(ctrl.cert_mean * 100.0f, 0.0f, 100.0f),
			std::clamp(ctrl.perp_mean, 0.0f, 100.0f),
			std::clamp(ctrl.loss_mean, 0.0f, 99.99f),
			opts.label == out[j]? '+': ' ',
			accc[0] + accc[1] + accc[2] >= 3? 'A': ' ',
			errc[0] + errc[1] + errc[2] >= 3? 'E': ' ',
			opts.label,
			std::clamp(ctrl.loss, 0.0f, 99.99f),
			std::clamp(ctrl.perp, 0.0f, 100.0f),
			std::clamp(ctrl.cert * 100.0f, 0.0f, 100.0f),
			vocab::debug(dbuf, out[j]).c_str(),
			out[j],
			pretty(tmbuf[0], milliseconds(last_time / bsz), 1).c_str(),
			pretty(tmbuf[1], si(cycles / bsz), 1).c_str(),
			pretty(tmbuf[2], milliseconds(ctrl.elapsed), 1).c_str()
		);

		log::logf
		{
			log, log::level::DEBUG,
			"%s",
			string_view{report, report_size}
		};
	}

	// `ret` already holds the number of tokens copied into `out` above. It is
	// deliberately not recomputed as ctrl.tokens - in.size(). That count can
	// exceed out.size(), and after a halt it also covers tokens that were
	// never copied. Either case would return a view over out-of-range or
	// unwritten elements.
	if ((false)) for(uint i(0); i < 3; ++i)
		if(accc_thresh[i] && ctrl.accept_seq[i] >= accc_thresh[i])
		{
			ret -= (3 - accc_thresh[i]);
			break;
		}
		else if(errc_thresh[i] && ctrl.error_seq[i] >= errc_thresh[i])
		{
			ret -= (3 - errc_thresh[i]);
			break;
		}

	ctx::interruption_point();
	return vector_view<u16>
	{
		out, ret
	};
}
|
|
|
|
|
2021-03-10 09:18:23 +01:00
|
|
|
void
|
|
|
|
ircd::gpt::embed(float *const out,
|
2021-03-09 11:08:47 +01:00
|
|
|
const u16 token,
|
2021-03-10 09:18:23 +01:00
|
|
|
const u16 position,
|
|
|
|
const opts &opts)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-10 09:18:23 +01:00
|
|
|
assert(opts.model);
|
2021-03-09 11:08:47 +01:00
|
|
|
const auto &wpe
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:14:55 +02:00
|
|
|
opts.model->word.pos[position]
|
2021-03-05 02:03:33 +01:00
|
|
|
};
|
|
|
|
|
2021-03-09 11:08:47 +01:00
|
|
|
const auto &wte
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:14:55 +02:00
|
|
|
opts.model->word.token[token]
|
2021-03-05 02:03:33 +01:00
|
|
|
};
|
|
|
|
|
2021-03-09 11:08:47 +01:00
|
|
|
for(uint j(0); j < 768; ++j)
|
|
|
|
out[j] = wte[j] + wpe[j];
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
uint16_t
|
2021-03-09 11:08:47 +01:00
|
|
|
ircd::gpt::argmax(const float *const __restrict__ logit,
|
|
|
|
const opts &opts)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-09 11:08:47 +01:00
|
|
|
static const auto max
|
|
|
|
{
|
|
|
|
32U
|
|
|
|
};
|
|
|
|
|
|
|
|
const auto top
|
|
|
|
{
|
|
|
|
std::clamp(opts.top_k, 1U, max - 1)
|
|
|
|
};
|
|
|
|
|
|
|
|
u16 best[max] {0};
|
2021-03-05 02:03:33 +01:00
|
|
|
for(uint j(0); j < vocab::tokens; ++j)
|
2021-03-09 11:08:47 +01:00
|
|
|
{
|
|
|
|
best[top] = j;
|
|
|
|
std::sort(begin(best), begin(best) + top + 1, [&logit]
|
|
|
|
(const auto &a, const auto &b)
|
|
|
|
{
|
|
|
|
return logit[a] > logit[b];
|
|
|
|
});
|
|
|
|
}
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-03-09 11:08:47 +01:00
|
|
|
const auto x
|
|
|
|
{
|
|
|
|
top > 1?
|
|
|
|
rand::integer(0, top - 1):
|
|
|
|
0
|
|
|
|
};
|
|
|
|
|
|
|
|
return best[x];
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
[[gnu::noinline]]
|
|
|
|
void
|
|
|
|
ircd::gpt::tail(float *const __restrict__ logit,
|
2021-03-09 11:08:47 +01:00
|
|
|
const float *const __restrict__ state,
|
2021-03-05 02:03:33 +01:00
|
|
|
const model::decoder &d)
|
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
constexpr float lnf_epsilon
|
|
|
|
{
|
|
|
|
0.00001
|
|
|
|
};
|
|
|
|
|
2021-03-05 02:03:33 +01:00
|
|
|
static float
|
2021-03-30 03:18:59 +02:00
|
|
|
buf alignas(64) [1][768];
|
|
|
|
for(uint i(0); i < 768; ++i)
|
|
|
|
buf[0][i] = state[i];
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
norm((f32x4 *)buf[0], (const f32x4 *)state, (const f32x4 *)d.f.bias, (const f32x4 *)d.f.weight, lnf_epsilon);
|
|
|
|
logits(logit, buf[0], d);
|
|
|
|
//logitsmax(logit, logit, vocab::tokens);
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
void
|
|
|
|
ircd::gpt::logits(float *const __restrict__ out,
|
2021-03-30 03:18:59 +02:00
|
|
|
const float *const __restrict__ in,
|
2021-03-05 02:03:33 +01:00
|
|
|
const model::decoder &d)
|
|
|
|
{
|
|
|
|
for(uint j(0); j < vocab::tokens; ++j)
|
|
|
|
out[j] = 0;
|
|
|
|
|
|
|
|
for(uint j(0); j < vocab::tokens; ++j)
|
|
|
|
for(uint k(0); k < 768; ++k)
|
2021-03-30 03:18:59 +02:00
|
|
|
out[j] += in[k] * d.word.token[j][k];
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
[[gnu::noinline]]
|
2021-03-05 02:03:33 +01:00
|
|
|
void
|
|
|
|
ircd::gpt::logitsmax(float *const out,
|
2021-03-30 03:18:59 +02:00
|
|
|
const float *const in,
|
|
|
|
const size_t num)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
static f64x4
|
|
|
|
exps alignas(4096) [2][65536 / 4];
|
|
|
|
|
|
|
|
math::smax<f32x4, f64x4>
|
|
|
|
(
|
|
|
|
{(f32x4 *)out, num / 4},
|
|
|
|
{(const f32x4 *)in, num / 4},
|
|
|
|
exps[0],
|
|
|
|
exps[1]
|
|
|
|
);
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
[[gnu::noinline]]
|
|
|
|
void
|
2021-03-30 03:18:59 +02:00
|
|
|
ircd::gpt::coil(float *__restrict__ accum,
|
|
|
|
const size_t tokens,
|
|
|
|
const model::decoder &decoder)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
|
|
|
static float
|
2021-03-30 03:18:59 +02:00
|
|
|
qkv alignas(4096) [3][1024][12][64],
|
|
|
|
state alignas(4096) [12][1024][1024],
|
|
|
|
attns alignas(4096) [12][1024][64];
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
for(uint i(0); i < 12; ++i)
|
|
|
|
{
|
|
|
|
const auto &layer
|
|
|
|
{
|
|
|
|
decoder.layer[i]
|
|
|
|
};
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
ctrl(qkv, accum, tokens, decoder, i);
|
2021-03-05 02:03:33 +01:00
|
|
|
pare(state, qkv, tokens);
|
|
|
|
mask(state, state, layer.attn.bias, tokens);
|
|
|
|
smax(state, state, tokens);
|
|
|
|
vals(attns, state, qkv, tokens);
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
static f32 a alignas(64) [1024][768];
|
|
|
|
memset(a, 0x0, 768 * tokens * sizeof(float));
|
|
|
|
for(uint j(0); j < tokens; j++)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
for(uint k(0); k < 12; k++)
|
|
|
|
for(uint l(0); l < 64; l++)
|
|
|
|
a[j][k * 64 + l] = attns[k][j][l];
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
2021-03-30 03:18:59 +02:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
static const math::fmma_opts fmma_opts
|
|
|
|
{
|
|
|
|
768, 768, 2U
|
|
|
|
};
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
for(uint j(0); j < tokens; ++j)
|
2021-04-11 04:28:23 +02:00
|
|
|
fmma((f32x4 *)(accum + j * 768), (const f32x4 *)(a[j]), (const f32x4 *)layer.attn.proj_bias, (const f32x4 *)layer.attn.proj_weight, fmma_opts);
|
2021-03-30 03:18:59 +02:00
|
|
|
|
|
|
|
for(uint j(0); j < tokens; ++j)
|
|
|
|
ffnn(accum + j * 768, accum + j * 768, decoder, i);
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void
|
|
|
|
ircd::gpt::ctrl(float (&__restrict__ out)[3][1024][12][64],
|
|
|
|
const float *const __restrict__ in,
|
|
|
|
const size_t num,
|
2021-03-30 03:18:59 +02:00
|
|
|
const model::decoder &decoder,
|
|
|
|
const uint laynum)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
constexpr float ln1_epsilon
|
|
|
|
{
|
|
|
|
0.00001
|
|
|
|
};
|
|
|
|
|
|
|
|
const auto &layer
|
|
|
|
{
|
|
|
|
decoder.layer[laynum]
|
|
|
|
};
|
|
|
|
|
2021-03-05 02:03:33 +01:00
|
|
|
float
|
|
|
|
(&__restrict__ qry)[1024][12][64] { out[0] },
|
|
|
|
(&__restrict__ key)[1024][12][64] { out[1] },
|
|
|
|
(&__restrict__ val)[1024][12][64] { out[2] };
|
|
|
|
|
|
|
|
for(uint i(0); i < num; ++i)
|
|
|
|
{
|
|
|
|
static float
|
|
|
|
buf alignas(64) [768],
|
|
|
|
proj alignas(64) [2304];
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
norm((f32x4 *)buf, (const f32x4 *)(in + i * 768), (const f32x4 *)layer.ln1.bias, (const f32x4 *)layer.ln1.weight, ln1_epsilon);
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
static const math::fmma_opts fmma_opts
|
|
|
|
{
|
|
|
|
768, 2304, 2U,
|
|
|
|
};
|
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
memset(proj, 0x0, sizeof(proj));
|
2021-04-11 04:28:23 +02:00
|
|
|
fmma((f32x4 *)proj, (const f32x4 *)buf, (const f32x4 *)layer.attn.attn_bias, (const f32x4 *)layer.attn.attn_weight, fmma_opts);
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < 64; ++k)
|
|
|
|
qry[i][j][k] = proj[768 * 0 + j * 64 + k];
|
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < 64; ++k)
|
|
|
|
key[i][j][k] = proj[768 * 1 + j * 64 + k];
|
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < 64; ++k)
|
|
|
|
val[i][j][k] = proj[768 * 2 + j * 64 + k];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void
|
|
|
|
ircd::gpt::pare(float (&__restrict__ out)[12][1024][1024],
|
|
|
|
const float (&__restrict__ qkv)[3][1024][12][64],
|
|
|
|
const size_t num)
|
|
|
|
{
|
|
|
|
const float
|
|
|
|
(&__restrict__ qry)[1024][12][64] { qkv[0] },
|
|
|
|
(&__restrict__ key)[1024][12][64] { qkv[1] },
|
|
|
|
(&__restrict__ val)[1024][12][64] { qkv[2] };
|
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
|
|
|
for(uint l(0); l < num; ++l)
|
|
|
|
out[j][k][l] = 0;
|
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
|
|
|
for(uint l(0); l < num; ++l)
|
|
|
|
for(uint m(0); m < 64; ++m)
|
|
|
|
out[j][k][l] += qry[k][j][m] * key[l][j][m];
|
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
|
|
|
for(uint l(0); l < num; ++l)
|
|
|
|
out[j][k][l] /= 8.0;
|
|
|
|
}
|
|
|
|
|
|
|
|
void
|
2021-03-30 03:18:59 +02:00
|
|
|
ircd::gpt::mask(float (&__restrict__ out)[12][1024][1024],
|
2021-03-05 02:03:33 +01:00
|
|
|
const float (&__restrict__ in)[12][1024][1024],
|
2021-03-30 03:18:59 +02:00
|
|
|
const bool (&__restrict__ bias)[1024][1024],
|
2021-03-05 02:03:33 +01:00
|
|
|
const size_t num)
|
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
static const float masked
|
|
|
|
{
|
|
|
|
-10000.0
|
|
|
|
};
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
|
|
|
for(uint l(0); l < num; ++l)
|
2021-03-30 03:18:59 +02:00
|
|
|
out[j][k][l] = bias[k][l]? in[j][k][l]: masked;
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
void
|
|
|
|
ircd::gpt::smax(float (&__restrict__ out)[12][1024][1024],
|
|
|
|
const float (&__restrict__ in)[12][1024][1024],
|
|
|
|
const size_t num)
|
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
static f64
|
|
|
|
tmp alignas(4096) [2][1024];
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
2021-03-30 03:18:59 +02:00
|
|
|
math::smax<f32, f64>
|
|
|
|
(
|
|
|
|
out[j][k], { in[j][k], num }, tmp[0], tmp[1]
|
|
|
|
);
|
|
|
|
}
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-03-30 03:18:59 +02:00
|
|
|
void
|
|
|
|
ircd::gpt::vals(float (&__restrict__ out)[12][1024][64],
|
|
|
|
const float (&__restrict__ in)[12][1024][1024],
|
|
|
|
const float (&__restrict__ qkv)[3][1024][12][64],
|
|
|
|
const size_t num)
|
|
|
|
{
|
|
|
|
const float
|
|
|
|
(&__restrict__ val)[1024][12][64] { qkv[2] };
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
2021-03-30 03:18:59 +02:00
|
|
|
for(uint l(0); l < 64; ++l)
|
|
|
|
out[j][k][l] = 0;
|
2021-03-05 02:03:33 +01:00
|
|
|
|
|
|
|
#pragma clang loop unroll (disable)
|
|
|
|
for(uint j(0); j < 12; ++j)
|
|
|
|
for(uint k(0); k < num; ++k)
|
|
|
|
for(uint l(0); l < num; ++l)
|
2021-03-30 03:18:59 +02:00
|
|
|
for(uint m(0); m < 64; ++m)
|
|
|
|
out[j][k][m] += in[j][k][l] * val[l][j][m];
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
void
|
2021-04-11 04:28:23 +02:00
|
|
|
ircd::gpt::ffnn(float *const out,
|
|
|
|
const float *const in,
|
|
|
|
const model::decoder &decoder,
|
|
|
|
const uint laynum)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
static const math::fmma_opts fmma3_opts
|
2021-03-30 03:18:59 +02:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
768, 3072, 2U,
|
2021-03-30 03:18:59 +02:00
|
|
|
};
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
static const math::fmma_opts fmma4_opts
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
3072, 768, 2U,
|
2021-03-05 02:03:33 +01:00
|
|
|
};
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
constexpr float ln2_epsilon
|
2021-03-30 03:18:59 +02:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
0.00001
|
2021-03-30 03:18:59 +02:00
|
|
|
};
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
const auto &layer
|
2021-03-30 03:18:59 +02:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
decoder.layer[laynum]
|
2021-03-30 03:18:59 +02:00
|
|
|
};
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
static float
|
|
|
|
buf alignas(64) [768],
|
|
|
|
buf2 alignas(64) [3072];
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
memset(buf2, 0x0, sizeof(buf2));
|
|
|
|
norm((f32x4 *)buf, (const f32x4 *)in, (const f32x4 *)layer.ln2.bias, (const f32x4 *)layer.ln2.weight, ln2_epsilon);
|
|
|
|
fmma((f32x4 *)buf2, (const f32x4 *)buf, (const f32x4 *)layer.ffnn.fc_bias, (const f32x4 *)layer.ffnn.fc_weight, fmma3_opts);
|
|
|
|
gelu((f32x4 *)buf2, (const f32x4 *)buf2);
|
|
|
|
fmma((f32x4 *)out, (const f32x4 *)buf2, (const f32x4 *)layer.ffnn.proj_bias, (const f32x4 *)layer.ffnn.proj_weight, fmma4_opts);
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
void
|
2021-04-11 04:28:23 +02:00
|
|
|
ircd::gpt::norm(f32x4 *const __restrict__ out,
|
|
|
|
const f32x4 *const __restrict__ in,
|
|
|
|
const f32x4 *const __restrict__ bias,
|
|
|
|
const f32x4 *const __restrict__ weight,
|
|
|
|
const float epsilon)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
static f64x4
|
|
|
|
tmp alignas(64) [768 / 4];
|
2021-03-30 03:18:59 +02:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
math::norm<f32x4, f64x4>
|
|
|
|
(
|
|
|
|
{out, 192}, {in, 192}, epsilon, tmp
|
|
|
|
);
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
for(uint j(0); j < 768 / 4; ++j)
|
|
|
|
out[j] = out[j] * weight[j] + bias[j];
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
template<class T>
|
2021-03-05 02:03:33 +01:00
|
|
|
void
|
2021-04-11 04:28:23 +02:00
|
|
|
ircd::gpt::fmma(T *const __restrict__ out,
|
|
|
|
const T *const __restrict__ in,
|
|
|
|
const T *const __restrict__ bias,
|
|
|
|
const T *const __restrict__ weight,
|
|
|
|
const math::fmma_opts &opts)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-04-11 04:28:23 +02:00
|
|
|
for(uint i(0); i < opts.rows / simd::lanes<T>(); ++i)
|
|
|
|
out[i] += bias[i];
|
2021-03-05 02:03:33 +01:00
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
math::fmma(out, in, weight, opts);
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
void
|
2021-03-30 03:18:59 +02:00
|
|
|
ircd::gpt::gelu(f32x4 *const out,
|
|
|
|
const f32x4 *const in)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
for(uint j(0); j < 3072 / 4; ++j)
|
2021-03-05 02:03:33 +01:00
|
|
|
gelu(out[j], in[j]);
|
|
|
|
}
|
|
|
|
|
|
|
|
void
|
2021-03-30 03:18:59 +02:00
|
|
|
ircd::gpt::gelu(f32x4 &out,
|
|
|
|
const f32x4 &in)
|
2021-03-05 02:03:33 +01:00
|
|
|
{
|
2021-03-30 03:18:59 +02:00
|
|
|
out = 0.5 * in * (1.0 + tanh(in * f32(0.7978845608) * (1.0 + f32(0.044715) * in * in)));
|
2021-03-05 02:03:33 +01:00
|
|
|
}
|
2021-04-17 20:53:50 +02:00
|
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
// gpt::task
|
|
|
|
//
|
|
|
|
|
|
|
|
ircd::gpt::task::task(const gpt::opts *const opts,
|
|
|
|
struct ircd_gpt_task *const ctrl)
|
|
|
|
:opts
|
|
|
|
{
|
|
|
|
opts
|
|
|
|
}
|
|
|
|
,ctrl
|
|
|
|
{
|
|
|
|
ctrl
|
|
|
|
}
|
|
|
|
{
|
|
|
|
memset(this->ctrl, 0x0, sizeof(ircd_gpt_task));
|
|
|
|
|
|
|
|
this->ctrl->rand[0] = this->opts->seed;
|
|
|
|
this->ctrl->rand[1] = this->opts->seed;
|
|
|
|
this->ctrl->rand[2] = -1UL;
|
|
|
|
this->ctrl->rand[3] = -1UL;
|
|
|
|
}
|
|
|
|
|
|
|
|
ircd::gpt::task::~task()
noexcept
{
	// The constructor only stores the caller-supplied opts/ctrl pointers and
	// allocates nothing, so there is nothing to release here.
}
|
|
|
|
|
|
|
|
//
|
|
|
|
// hypercall
|
|
|
|
//
|
|
|
|
|
|
|
|
/// Human-readable name for a hypercall code. Unknown codes map to "??????".
ircd::string_view
ircd::gpt::reflect(const enum ircd_gpt_hypercall code)
noexcept
{
	switch(code)
	{
		case IRCD_GPT_ACCEPT:
			return "ACCEPT";

		case IRCD_GPT_ECOMPLETE:
			return "ECOMPLETE";

		case IRCD_GPT_ETOKENS:
			return "ETOKENS";
	}

	return "??????";
}
|