mirror of
https://github.com/matrix-construct/construct
synced 2024-11-16 15:00:51 +01:00
ircd::gpt: Add basic interface; add options, context.
This commit is contained in:
parent
4458235dfa
commit
33a1ffd4bf
2 changed files with 198 additions and 67 deletions
|
@ -17,21 +17,56 @@ namespace ircd::gpt
|
|||
{
|
||||
IRCD_EXCEPTION(ircd::error, error)
|
||||
|
||||
u16
|
||||
generate(const vector_view<const f32> &) noexcept;
|
||||
|
||||
vector_view<f32>
|
||||
embed(const vector_view<f32> &,
|
||||
const vector_view<const u16> &) noexcept;
|
||||
struct opts;
|
||||
struct context;
|
||||
|
||||
extern const opts default_opts;
|
||||
extern log::log log;
|
||||
|
||||
vector_view<u16>
|
||||
generate(const vector_view<u16> &out,
|
||||
const vector_view<const u16> &in,
|
||||
const opts & = default_opts);
|
||||
|
||||
string_view
|
||||
generate(const mutable_buffer &out,
|
||||
const string_view &in,
|
||||
const opts & = default_opts);
|
||||
}
|
||||
|
||||
#include "vocab.h"
|
||||
#include "model.h"
|
||||
|
||||
namespace ircd::gpt
|
||||
struct ircd::gpt::opts
|
||||
{
|
||||
using vocab::detokenize;
|
||||
using vocab::tokenize;
|
||||
}
|
||||
/// Specifies the nominal halting condition based on the sequence of
|
||||
/// tokens. Generation will complete when this sequence is witnessed. Set
|
||||
/// tokens to -1 starting from the back to not match that token. Setting
|
||||
/// all tokens to -1 will ignore this condition.
|
||||
uint accept_code[3][3]
|
||||
{
|
||||
{ 13, 198, -1U, },
|
||||
{ 198, 198, -1U, },
|
||||
{ -1U, -1U, -1U, },
|
||||
};
|
||||
|
||||
/// Specifies the exceptional halting condition based on the sequence of
|
||||
/// tokens. By default, the three zeros represent three outputs of '!'
|
||||
/// which is probably an error code; note that a true "!!!" is represented
|
||||
/// by token number 10185. Set tokens to -1 starting from the back to
|
||||
/// not match that token; generated output after errors is usually garbage.
|
||||
uint error_code[3][3]
|
||||
{
|
||||
{ 0, 0, 0, },
|
||||
{ -1U, 0, 0, },
|
||||
{ -1U, 0, 0, },
|
||||
};
|
||||
|
||||
/// Limit number of output tokens. Default of -1 is unlimited; the number
|
||||
/// of tokens generated will be limited by other factors.
|
||||
uint limit {-1U};
|
||||
|
||||
/// Flip random coins over the top k logits each round. Setting to 1
|
||||
/// deterministically selects the top logit.
|
||||
uint top_k {2};
|
||||
};
|
||||
|
|
210
ircd/gpt.cc
210
ircd/gpt.cc
|
@ -8,17 +8,11 @@
|
|||
// copyright notice and this permission notice is present in all copies. The
|
||||
// full license for this software is available in the LICENSE file.
|
||||
|
||||
decltype(ircd::gpt::log)
|
||||
ircd::gpt::log
|
||||
{
|
||||
"gpt"
|
||||
};
|
||||
|
||||
namespace ircd::gpt
|
||||
{
|
||||
static void gelu(float &, const float &);
|
||||
static void gelu(float (&)[3072], const float (&)[3072]);
|
||||
static void norm(float (&)[768], const float (&)[768], const float (&)[768], const float (&)[768], const float);
|
||||
static void norm(float (&)[768], const float *, const float (&)[768], const float (&)[768], const float);
|
||||
static void fmma(float (&)[768], const float (&)[3072], const float (&)[768], const float (&)[3072][768]);
|
||||
static void fmma(float (&)[3072], const float (&)[768], const float (&)[3072], const float (&)[768][3072]);
|
||||
static void fmma(float (&)[2304], const float (&)[768], const float (&)[2304], const float (&)[768][2304]);
|
||||
|
@ -32,8 +26,10 @@ namespace ircd::gpt
|
|||
static void transform(float *, const size_t, const model::decoder &);
|
||||
static void logitsmax(float *, const float *);
|
||||
static void logits(float *, const float (&)[768], const model::decoder &);
|
||||
static void tail(float *, const float (&)[768], const model::decoder &);
|
||||
static u16 argmax(const float *);
|
||||
static void tail(float *, const float *, const model::decoder &);
|
||||
static u16 argmax(const float *, const opts &);
|
||||
|
||||
static vector_view<f32> embed(const vector_view<f32> &out, const u16 token, const u16 position);
|
||||
|
||||
std::unique_ptr<model::decoder> device
|
||||
{
|
||||
|
@ -45,6 +41,15 @@ namespace ircd::gpt
|
|||
scratch alignas(64) [1024 * 768];
|
||||
}
|
||||
|
||||
decltype(ircd::gpt::default_opts)
|
||||
ircd::gpt::default_opts;
|
||||
|
||||
decltype(ircd::gpt::log)
|
||||
ircd::gpt::log
|
||||
{
|
||||
"gpt"
|
||||
};
|
||||
|
||||
namespace ircd::gpt::model
|
||||
{
|
||||
constexpr float embed_pdrop
|
||||
|
@ -78,81 +83,172 @@ namespace ircd::gpt::model
|
|||
};
|
||||
}
|
||||
|
||||
ircd::string_view
|
||||
ircd::gpt::generate(const mutable_buffer &out,
|
||||
const string_view &in,
|
||||
const opts &opts)
|
||||
{
|
||||
u16 buf[2][256];
|
||||
const auto input_tokens
|
||||
{
|
||||
vocab::tokenize(buf[0], in)
|
||||
};
|
||||
|
||||
const auto output_tokens
|
||||
{
|
||||
generate(buf[1], input_tokens, opts)
|
||||
};
|
||||
|
||||
const auto output
|
||||
{
|
||||
vocab::detokenize(out, output_tokens)
|
||||
};
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
ircd::vector_view<ircd::u16>
|
||||
ircd::gpt::generate(const vector_view<u16> &out,
|
||||
const vector_view<const u16> &in,
|
||||
const opts &opts)
|
||||
{
|
||||
size_t ret(0);
|
||||
bool halt(false);
|
||||
uint errc[3] {0}, accc[3] {0};
|
||||
for(uint i(0); !halt && i < out.size() && ret < opts.limit; ++i)
|
||||
{
|
||||
const size_t tokens
|
||||
{
|
||||
in.size() + i
|
||||
};
|
||||
|
||||
const vector_view<f32> scratch
|
||||
{
|
||||
gpt::scratch, tokens * 768
|
||||
};
|
||||
|
||||
for(uint j(0); j < in.size(); ++j)
|
||||
{
|
||||
const vector_view<f32> dst
|
||||
{
|
||||
data(scratch) + j * 768, 768
|
||||
};
|
||||
|
||||
const auto embedding
|
||||
{
|
||||
embed(dst, in[j], j)
|
||||
};
|
||||
}
|
||||
|
||||
for(uint j(0); j < ret; ++j)
|
||||
{
|
||||
const vector_view<f32> dst
|
||||
{
|
||||
data(scratch) + (in.size() + j) * 768, 768
|
||||
};
|
||||
|
||||
const auto embedding
|
||||
{
|
||||
embed(dst, out[j], in.size() + j)
|
||||
};
|
||||
}
|
||||
|
||||
transform(data(scratch), tokens, *device);
|
||||
|
||||
const vector_view<f32> last_embed
|
||||
{
|
||||
data(scratch) + ((tokens - 1) * 768), 768
|
||||
};
|
||||
|
||||
tail(logit, data(last_embed), *device);
|
||||
out[i] = argmax(logit, opts);
|
||||
|
||||
for(uint j(0); j < 3; ++j)
|
||||
{
|
||||
errc[j] = out[i] == opts.error_code[j][errc[j]]? errc[j] + 1: 0;
|
||||
accc[j] = out[i] == opts.accept_code[j][accc[j]]? accc[j] + 1: 0;
|
||||
}
|
||||
|
||||
for(uint j(0); j < 3; ++j)
|
||||
{
|
||||
halt |= errc[j] >= 3 || (errc[j] && opts.error_code[j][errc[j] + 1] == -1U);
|
||||
halt |= accc[j] >= 3 || (accc[j] && opts.accept_code[j][accc[j] + 1] == -1U);
|
||||
}
|
||||
|
||||
++ret;
|
||||
}
|
||||
|
||||
return vector_view<u16>
|
||||
{
|
||||
out, ret
|
||||
};
|
||||
}
|
||||
|
||||
ircd::vector_view<ircd::f32>
|
||||
ircd::gpt::embed(const vector_view<f32> &out,
|
||||
const vector_view<const u16> &in)
|
||||
noexcept
|
||||
const u16 token,
|
||||
const u16 position)
|
||||
{
|
||||
assert(device);
|
||||
|
||||
uint i(0);
|
||||
for(; i < in.size(); ++i)
|
||||
const auto &wpe
|
||||
{
|
||||
const auto &wpe
|
||||
{
|
||||
device->wpe[i]
|
||||
};
|
||||
device->wpe[position]
|
||||
};
|
||||
|
||||
const auto &wte
|
||||
{
|
||||
device->wte[in[i]]
|
||||
};
|
||||
const auto &wte
|
||||
{
|
||||
device->wte[token]
|
||||
};
|
||||
|
||||
for(uint j(0); j < 768; ++j)
|
||||
out[i * 768 + j] = wte[j] + wpe[j];
|
||||
}
|
||||
for(uint j(0); j < 768; ++j)
|
||||
out[j] = wte[j] + wpe[j];
|
||||
|
||||
return vector_view<f32>
|
||||
{
|
||||
data(out), i * 768
|
||||
data(out), 768
|
||||
};
|
||||
}
|
||||
|
||||
uint16_t
|
||||
ircd::gpt::generate(const vector_view<const f32> &in)
|
||||
noexcept
|
||||
ircd::gpt::argmax(const float *const __restrict__ logit,
|
||||
const opts &opts)
|
||||
{
|
||||
always_assert(in.size() % 768 == 0);
|
||||
const auto toks
|
||||
static const auto max
|
||||
{
|
||||
in.size() / 768
|
||||
32U
|
||||
};
|
||||
|
||||
const vector_view<f32> scratch
|
||||
const auto top
|
||||
{
|
||||
gpt::scratch, in.size()
|
||||
std::clamp(opts.top_k, 1U, max - 1)
|
||||
};
|
||||
|
||||
for(uint i(0); i < in.size(); ++i)
|
||||
scratch[i] = in[i];
|
||||
|
||||
transform(data(scratch), toks, *device);
|
||||
|
||||
static float
|
||||
buf alignas(64) [768];
|
||||
|
||||
for(uint i(0); i < 768; ++i)
|
||||
buf[i] = scratch[(toks - 1) * 768 + i];
|
||||
|
||||
tail(logit, buf, *device);
|
||||
return argmax(logit);
|
||||
}
|
||||
|
||||
uint16_t
|
||||
ircd::gpt::argmax(const float *const __restrict__ logit)
|
||||
{
|
||||
u16 ret(0);
|
||||
u16 best[max] {0};
|
||||
for(uint j(0); j < vocab::tokens; ++j)
|
||||
if(logit[j] > logit[ret])
|
||||
ret = j;
|
||||
{
|
||||
best[top] = j;
|
||||
std::sort(begin(best), begin(best) + top + 1, [&logit]
|
||||
(const auto &a, const auto &b)
|
||||
{
|
||||
return logit[a] > logit[b];
|
||||
});
|
||||
}
|
||||
|
||||
return ret;
|
||||
const auto x
|
||||
{
|
||||
top > 1?
|
||||
rand::integer(0, top - 1):
|
||||
0
|
||||
};
|
||||
|
||||
return best[x];
|
||||
}
|
||||
|
||||
[[gnu::noinline]]
|
||||
void
|
||||
ircd::gpt::tail(float *const __restrict__ logit,
|
||||
const float (&__restrict__ state)[768],
|
||||
const float *const __restrict__ state,
|
||||
const model::decoder &d)
|
||||
{
|
||||
static float
|
||||
|
@ -396,7 +492,7 @@ ircd::gpt::mask(float (&__restrict__ out)[12][1024][1024],
|
|||
|
||||
void
|
||||
ircd::gpt::norm(float (&__restrict__ out)[768],
|
||||
const float (&__restrict__ in)[768],
|
||||
const float *const in,
|
||||
const float (&__restrict__ bias)[768],
|
||||
const float (&__restrict__ weight)[768],
|
||||
const float epsilon)
|
||||
|
@ -406,7 +502,7 @@ ircd::gpt::norm(float (&__restrict__ out)[768],
|
|||
|
||||
const float mean
|
||||
{
|
||||
math::mean<float>(in)
|
||||
math::mean<float>(vector_view<const float>{in, 768})
|
||||
};
|
||||
|
||||
for(uint j(0); j < 768; ++j)
|
||||
|
|
Loading…
Reference in a new issue