
ircd::gpt: Add basic interface; add options, context.

Jason Volk 2021-03-09 02:08:47 -08:00
parent 4458235dfa
commit 33a1ffd4bf
2 changed files with 198 additions and 67 deletions


@@ -17,21 +17,56 @@ namespace ircd::gpt
 {
 	IRCD_EXCEPTION(ircd::error, error)

-	u16
-	generate(const vector_view<const f32> &) noexcept;
-
-	vector_view<f32>
-	embed(const vector_view<f32> &,
-	      const vector_view<const u16> &) noexcept;
+	struct opts;
+	struct context;
+
+	extern const opts default_opts;
+	extern log::log log;
+
+	vector_view<u16>
+	generate(const vector_view<u16> &out,
+	         const vector_view<const u16> &in,
+	         const opts & = default_opts);
+
+	string_view
+	generate(const mutable_buffer &out,
+	         const string_view &in,
+	         const opts & = default_opts);
 }

 #include "vocab.h"
 #include "model.h"

-namespace ircd::gpt
+struct ircd::gpt::opts
 {
-	using vocab::detokenize;
-	using vocab::tokenize;
-}
+	/// Specifies the nominal halting condition based on the sequence of
+	/// tokens. Generation will complete when this sequence is witnessed. Set
+	/// tokens to -1 starting from the back to not match that token. Setting
+	/// all tokens to -1 will ignore this condition.
+	uint accept_code[3][3]
+	{
+		{ 13, 198, -1U, },
+		{ 198, 198, -1U, },
+		{ -1U, -1U, -1U, },
+	};
+
+	/// Specifies the exceptional halting condition based on the sequence of
+	/// tokens. By default, the three zeros represent three outputs of '!'
+	/// which is probably an error code; note that a true "!!!" is represented
+	/// by token number 10185. Set tokens to -1 starting from the back to
+	/// not match that token; generated output after errors is usually garbage.
+	uint error_code[3][3]
+	{
+		{ 0, 0, 0, },
+		{ -1U, 0, 0, },
+		{ -1U, 0, 0, },
+	};
+
+	/// Limit number of output tokens. Default of -1 is unlimited; the number
+	/// of tokens generated will be limited by other factors.
+	uint limit {-1U};
+
+	/// Flip random coins over the top k logits each round. Setting to 1
+	/// deterministically selects the top logit.
+	uint top_k {2};
+};
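
The new public interface works on untokenized text: the string overload tokenizes the input, runs the token-level generator, and detokenizes the result. A minimal caller sketch, assuming the surrounding ircd environment is available; the buffer size, prompt, and option values here are illustrative and not part of this commit:

#include <ircd/ircd.h>

// Hypothetical caller of the interface added by this commit.
void example()
{
	char buf[1024];                  // output text buffer (illustrative size)

	ircd::gpt::opts opts;            // start from the defaults shown above
	opts.limit = 64;                 // cap generation at 64 tokens
	opts.top_k = 3;                  // sample among the top three logits

	const ircd::string_view output
	{
		ircd::gpt::generate(buf, "The quick brown fox", opts)
	};
}

Generation halts when an accept_code row is fully matched, when an error_code row is fully matched, when opts.limit tokens have been produced, or when the output token buffer fills.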


@@ -8,17 +8,11 @@
 // copyright notice and this permission notice is present in all copies. The
 // full license for this software is available in the LICENSE file.

-decltype(ircd::gpt::log)
-ircd::gpt::log
-{
-	"gpt"
-};
-
 namespace ircd::gpt
 {
 	static void gelu(float &, const float &);
 	static void gelu(float (&)[3072], const float (&)[3072]);
-	static void norm(float (&)[768], const float (&)[768], const float (&)[768], const float (&)[768], const float);
+	static void norm(float (&)[768], const float *, const float (&)[768], const float (&)[768], const float);
 	static void fmma(float (&)[768], const float (&)[3072], const float (&)[768], const float (&)[3072][768]);
 	static void fmma(float (&)[3072], const float (&)[768], const float (&)[3072], const float (&)[768][3072]);
 	static void fmma(float (&)[2304], const float (&)[768], const float (&)[2304], const float (&)[768][2304]);
@@ -32,8 +26,10 @@ namespace ircd::gpt
 	static void transform(float *, const size_t, const model::decoder &);
 	static void logitsmax(float *, const float *);
 	static void logits(float *, const float (&)[768], const model::decoder &);
-	static void tail(float *, const float (&)[768], const model::decoder &);
-	static u16 argmax(const float *);
+	static void tail(float *, const float *, const model::decoder &);
+	static u16 argmax(const float *, const opts &);
+	static vector_view<f32> embed(const vector_view<f32> &out, const u16 token, const u16 position);

 	std::unique_ptr<model::decoder> device
 	{
@@ -45,6 +41,15 @@ namespace ircd::gpt
 	scratch alignas(64) [1024 * 768];
 }

+decltype(ircd::gpt::default_opts)
+ircd::gpt::default_opts;
+
+decltype(ircd::gpt::log)
+ircd::gpt::log
+{
+	"gpt"
+};
+
 namespace ircd::gpt::model
 {
 	constexpr float embed_pdrop
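
The embed helper declared above (and defined in the next hunk) sums one token-embedding row (wte) with one position-embedding row (wpe). A freestanding sketch of that step, assuming the GPT-2 small dimensions used throughout this file; the 50257-token vocabulary and 1024 positions are assumptions, not shown in this diff:

#include <cstdint>

// Each input position becomes the elementwise sum of its token row and
// its position row in the learned embedding tables.
static void
embed_token(float *const out,
            const float (&wte)[50257][768],
            const float (&wpe)[1024][768],
            const uint16_t token,
            const uint16_t position)
{
	for(unsigned j(0); j < 768; ++j)
		out[j] = wte[token][j] + wpe[position][j];
}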
@@ -78,81 +83,172 @@ namespace ircd::gpt::model
 	};
 }

-ircd::vector_view<ircd::f32>
-ircd::gpt::embed(const vector_view<f32> &out,
-                 const vector_view<const u16> &in)
-noexcept
+ircd::string_view
+ircd::gpt::generate(const mutable_buffer &out,
+                    const string_view &in,
+                    const opts &opts)
 {
-	assert(device);
-	uint i(0);
-	for(; i < in.size(); ++i)
-	{
-		const auto &wpe
-		{
-			device->wpe[i]
-		};
-
-		const auto &wte
-		{
-			device->wte[in[i]]
-		};
-
-		for(uint j(0); j < 768; ++j)
-			out[i * 768 + j] = wte[j] + wpe[j];
-	}
+	u16 buf[2][256];
+	const auto input_tokens
+	{
+		vocab::tokenize(buf[0], in)
+	};

-	return vector_view<f32>
+	const auto output_tokens
 	{
-		data(out), i * 768
+		generate(buf[1], input_tokens, opts)
 	};
+
+	const auto output
+	{
+		vocab::detokenize(out, output_tokens)
+	};
+
+	return output;
 }

-uint16_t
-ircd::gpt::generate(const vector_view<const f32> &in)
-noexcept
+ircd::vector_view<ircd::u16>
+ircd::gpt::generate(const vector_view<u16> &out,
+                    const vector_view<const u16> &in,
+                    const opts &opts)
 {
-	always_assert(in.size() % 768 == 0);
-	const auto toks
+	size_t ret(0);
+	bool halt(false);
+	uint errc[3] {0}, accc[3] {0};
+	for(uint i(0); !halt && i < out.size() && ret < opts.limit; ++i)
 	{
-		in.size() / 768
-	};
+		const size_t tokens
+		{
+			in.size() + i
+		};

-	const vector_view<f32> scratch
-	{
-		gpt::scratch, in.size()
-	};
+		const vector_view<f32> scratch
+		{
+			gpt::scratch, tokens * 768
+		};

-	for(uint i(0); i < in.size(); ++i)
-		scratch[i] = in[i];
+		for(uint j(0); j < in.size(); ++j)
+		{
+			const vector_view<f32> dst
+			{
+				data(scratch) + j * 768, 768
+			};

-	transform(data(scratch), toks, *device);
+			const auto embedding
+			{
+				embed(dst, in[j], j)
+			};
+		}

-	static float
-	buf alignas(64) [768];
+		for(uint j(0); j < ret; ++j)
+		{
+			const vector_view<f32> dst
+			{
+				data(scratch) + (in.size() + j) * 768, 768
+			};

-	for(uint i(0); i < 768; ++i)
-		buf[i] = scratch[(toks - 1) * 768 + i];
+			const auto embedding
+			{
+				embed(dst, out[j], in.size() + j)
+			};
+		}

-	tail(logit, buf, *device);
-	return argmax(logit);
+		transform(data(scratch), tokens, *device);
+
+		const vector_view<f32> last_embed
+		{
+			data(scratch) + ((tokens - 1) * 768), 768
+		};
+
+		tail(logit, data(last_embed), *device);
+		out[i] = argmax(logit, opts);
+
+		for(uint j(0); j < 3; ++j)
+		{
+			errc[j] = out[i] == opts.error_code[j][errc[j]]? errc[j] + 1: 0;
+			accc[j] = out[i] == opts.accept_code[j][accc[j]]? accc[j] + 1: 0;
+		}
+
+		for(uint j(0); j < 3; ++j)
+		{
+			halt |= errc[j] >= 3 || (errc[j] && opts.error_code[j][errc[j] + 1] == -1U);
+			halt |= accc[j] >= 3 || (accc[j] && opts.accept_code[j][accc[j] + 1] == -1U);
+		}
+
+		++ret;
+	}
+
+	return vector_view<u16>
+	{
+		out, ret
+	};
+}
+
+ircd::vector_view<ircd::f32>
+ircd::gpt::embed(const vector_view<f32> &out,
+                 const u16 token,
+                 const u16 position)
+{
+	assert(device);
+	const auto &wpe
+	{
+		device->wpe[position]
+	};
+
+	const auto &wte
+	{
+		device->wte[token]
+	};
+
+	for(uint j(0); j < 768; ++j)
+		out[j] = wte[j] + wpe[j];
+
+	return vector_view<f32>
+	{
+		data(out), 768
+	};
 }

 uint16_t
-ircd::gpt::argmax(const float *const __restrict__ logit)
+ircd::gpt::argmax(const float *const __restrict__ logit,
+                  const opts &opts)
 {
-	u16 ret(0);
-	for(uint j(0); j < vocab::tokens; ++j)
-		if(logit[j] > logit[ret])
-			ret = j;
+	static const auto max
+	{
+		32U
+	};

-	return ret;
+	const auto top
+	{
+		std::clamp(opts.top_k, 1U, max - 1)
+	};
+
+	u16 best[max] {0};
+	for(uint j(0); j < vocab::tokens; ++j)
+	{
+		best[top] = j;
+		std::sort(begin(best), begin(best) + top + 1, [&logit]
+		(const auto &a, const auto &b)
+		{
+			return logit[a] > logit[b];
+		});
+	}
+
+	const auto x
+	{
+		top > 1?
+			rand::integer(0, top - 1):
+			0
+	};
+
+	return best[x];
 }

 [[gnu::noinline]]
 void
 ircd::gpt::tail(float *const __restrict__ logit,
-                const float (&__restrict__ state)[768],
+                const float *const __restrict__ state,
                 const model::decoder &d)
 {
 	static float
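
The rewritten argmax above implements top-k sampling: it keeps a small sorted window of the k strongest logits and flips a coin among them via ircd's rand::integer, with a top_k of 1 degenerating to a plain argmax. A self-contained sketch of the same selection scheme using only the standard library; the helper name and RNG choice are illustrative, not the commit's code:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Pick uniformly among the top_k highest logits; top_k == 1 is greedy argmax.
static uint16_t
sample_top_k(const float *const logit,
             const size_t tokens,
             size_t top_k)
{
	std::vector<uint16_t> idx(tokens);
	for(size_t i(0); i < tokens; ++i)
		idx[i] = i;

	// Move the top_k best-scoring token ids to the front, best first.
	top_k = std::clamp<size_t>(top_k, 1, tokens);
	std::partial_sort(begin(idx), begin(idx) + top_k, end(idx), [logit]
	(const auto a, const auto b)
	{
		return logit[a] > logit[b];
	});

	static std::mt19937 rng{std::random_device{}()};
	std::uniform_int_distribution<size_t> coin{0, top_k - 1};
	return idx[coin(rng)];
}

Where the committed version re-sorts a fixed 32-slot window once per vocabulary entry, this sketch does a single partial sort over all the logits; both scan the whole vocabulary each round.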
@@ -396,7 +492,7 @@ ircd::gpt::mask(float (&__restrict__ out)[12][1024][1024],
 void
 ircd::gpt::norm(float (&__restrict__ out)[768],
-                const float (&__restrict__ in)[768],
+                const float *const in,
                 const float (&__restrict__ bias)[768],
                 const float (&__restrict__ weight)[768],
                 const float epsilon)
@@ -406,7 +502,7 @@ ircd::gpt::norm(float (&__restrict__ out)[768],
 	const float mean
 	{
-		math::mean<float>(in)
+		math::mean<float>(vector_view<const float>{in, 768})
 	};

 	for(uint j(0); j < 768; ++j)
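
For reference, what this function computes is a standard layer normalization of the 768-wide state; the signature change above lets it normalize any row of the scratch buffer through a raw pointer. A self-contained sketch under that assumption (the committed version delegates the statistics to ircd::math):

#include <cmath>
#include <cstddef>

// out[j] = weight[j] * (in[j] - mean) / sqrt(variance + epsilon) + bias[j]
static void
layer_norm(float *const out,
           const float *const in,
           const float (&bias)[768],
           const float (&weight)[768],
           const float epsilon)
{
	float mean(0.0f);
	for(size_t j(0); j < 768; ++j)
		mean += in[j];
	mean /= 768.0f;

	float var(0.0f);
	for(size_t j(0); j < 768; ++j)
		var += (in[j] - mean) * (in[j] - mean);
	var /= 768.0f;

	const float sigma(std::sqrt(var + epsilon));
	for(size_t j(0); j < 768; ++j)
		out[j] = weight[j] * ((in[j] - mean) / sigma) + bias[j];
}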