mirror of
https://github.com/matrix-construct/construct
synced 2024-11-16 23:10:54 +01:00
ircd::gpt: Add basic interface; add options, context.
This commit is contained in:
parent
4458235dfa
commit
33a1ffd4bf
2 changed files with 198 additions and 67 deletions
|
@ -17,21 +17,56 @@ namespace ircd::gpt
|
||||||
{
|
{
|
||||||
IRCD_EXCEPTION(ircd::error, error)
|
IRCD_EXCEPTION(ircd::error, error)
|
||||||
|
|
||||||
u16
|
struct opts;
|
||||||
generate(const vector_view<const f32> &) noexcept;
|
struct context;
|
||||||
|
|
||||||
vector_view<f32>
|
|
||||||
embed(const vector_view<f32> &,
|
|
||||||
const vector_view<const u16> &) noexcept;
|
|
||||||
|
|
||||||
|
extern const opts default_opts;
|
||||||
extern log::log log;
|
extern log::log log;
|
||||||
|
|
||||||
|
vector_view<u16>
|
||||||
|
generate(const vector_view<u16> &out,
|
||||||
|
const vector_view<const u16> &in,
|
||||||
|
const opts & = default_opts);
|
||||||
|
|
||||||
|
string_view
|
||||||
|
generate(const mutable_buffer &out,
|
||||||
|
const string_view &in,
|
||||||
|
const opts & = default_opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
#include "vocab.h"
|
#include "vocab.h"
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
|
||||||
namespace ircd::gpt
|
struct ircd::gpt::opts
|
||||||
{
|
{
|
||||||
using vocab::detokenize;
|
/// Specifies the nominal halting condition based on the sequence of
|
||||||
using vocab::tokenize;
|
/// tokens. Generation will complete when this sequence is witnessed. Set
|
||||||
}
|
/// tokens to -1 starting from the back to not match that token. Setting
|
||||||
|
/// all tokens to -1 will ignore this condition.
|
||||||
|
uint accept_code[3][3]
|
||||||
|
{
|
||||||
|
{ 13, 198, -1U, },
|
||||||
|
{ 198, 198, -1U, },
|
||||||
|
{ -1U, -1U, -1U, },
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Specifies the exceptional halting condition based on the sequence of
|
||||||
|
/// tokens. By default, the three zeros represent three outputs of '!'
|
||||||
|
/// which is probably an error code; note that a true "!!!" is represented
|
||||||
|
/// by token number 10185. Set tokens to -1 starting from the back to
|
||||||
|
/// not match that token; generated output after errors is usually garbage.
|
||||||
|
uint error_code[3][3]
|
||||||
|
{
|
||||||
|
{ 0, 0, 0, },
|
||||||
|
{ -1U, 0, 0, },
|
||||||
|
{ -1U, 0, 0, },
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Limit number of output tokens. Default of -1 is unlimited; the number
|
||||||
|
/// of tokens generated will be limited by other factors.
|
||||||
|
uint limit {-1U};
|
||||||
|
|
||||||
|
/// Flip random coins over the top k logits each round. Setting to 1
|
||||||
|
/// deterministically selects the top logit.
|
||||||
|
uint top_k {2};
|
||||||
|
};
|
||||||
|
|
210
ircd/gpt.cc
210
ircd/gpt.cc
|
@ -8,17 +8,11 @@
|
||||||
// copyright notice and this permission notice is present in all copies. The
|
// copyright notice and this permission notice is present in all copies. The
|
||||||
// full license for this software is available in the LICENSE file.
|
// full license for this software is available in the LICENSE file.
|
||||||
|
|
||||||
decltype(ircd::gpt::log)
|
|
||||||
ircd::gpt::log
|
|
||||||
{
|
|
||||||
"gpt"
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace ircd::gpt
|
namespace ircd::gpt
|
||||||
{
|
{
|
||||||
static void gelu(float &, const float &);
|
static void gelu(float &, const float &);
|
||||||
static void gelu(float (&)[3072], const float (&)[3072]);
|
static void gelu(float (&)[3072], const float (&)[3072]);
|
||||||
static void norm(float (&)[768], const float (&)[768], const float (&)[768], const float (&)[768], const float);
|
static void norm(float (&)[768], const float *, const float (&)[768], const float (&)[768], const float);
|
||||||
static void fmma(float (&)[768], const float (&)[3072], const float (&)[768], const float (&)[3072][768]);
|
static void fmma(float (&)[768], const float (&)[3072], const float (&)[768], const float (&)[3072][768]);
|
||||||
static void fmma(float (&)[3072], const float (&)[768], const float (&)[3072], const float (&)[768][3072]);
|
static void fmma(float (&)[3072], const float (&)[768], const float (&)[3072], const float (&)[768][3072]);
|
||||||
static void fmma(float (&)[2304], const float (&)[768], const float (&)[2304], const float (&)[768][2304]);
|
static void fmma(float (&)[2304], const float (&)[768], const float (&)[2304], const float (&)[768][2304]);
|
||||||
|
@ -32,8 +26,10 @@ namespace ircd::gpt
|
||||||
static void transform(float *, const size_t, const model::decoder &);
|
static void transform(float *, const size_t, const model::decoder &);
|
||||||
static void logitsmax(float *, const float *);
|
static void logitsmax(float *, const float *);
|
||||||
static void logits(float *, const float (&)[768], const model::decoder &);
|
static void logits(float *, const float (&)[768], const model::decoder &);
|
||||||
static void tail(float *, const float (&)[768], const model::decoder &);
|
static void tail(float *, const float *, const model::decoder &);
|
||||||
static u16 argmax(const float *);
|
static u16 argmax(const float *, const opts &);
|
||||||
|
|
||||||
|
static vector_view<f32> embed(const vector_view<f32> &out, const u16 token, const u16 position);
|
||||||
|
|
||||||
std::unique_ptr<model::decoder> device
|
std::unique_ptr<model::decoder> device
|
||||||
{
|
{
|
||||||
|
@ -45,6 +41,15 @@ namespace ircd::gpt
|
||||||
scratch alignas(64) [1024 * 768];
|
scratch alignas(64) [1024 * 768];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
decltype(ircd::gpt::default_opts)
|
||||||
|
ircd::gpt::default_opts;
|
||||||
|
|
||||||
|
decltype(ircd::gpt::log)
|
||||||
|
ircd::gpt::log
|
||||||
|
{
|
||||||
|
"gpt"
|
||||||
|
};
|
||||||
|
|
||||||
namespace ircd::gpt::model
|
namespace ircd::gpt::model
|
||||||
{
|
{
|
||||||
constexpr float embed_pdrop
|
constexpr float embed_pdrop
|
||||||
|
@ -78,81 +83,172 @@ namespace ircd::gpt::model
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ircd::string_view
|
||||||
|
ircd::gpt::generate(const mutable_buffer &out,
|
||||||
|
const string_view &in,
|
||||||
|
const opts &opts)
|
||||||
|
{
|
||||||
|
u16 buf[2][256];
|
||||||
|
const auto input_tokens
|
||||||
|
{
|
||||||
|
vocab::tokenize(buf[0], in)
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto output_tokens
|
||||||
|
{
|
||||||
|
generate(buf[1], input_tokens, opts)
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto output
|
||||||
|
{
|
||||||
|
vocab::detokenize(out, output_tokens)
|
||||||
|
};
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
ircd::vector_view<ircd::u16>
|
||||||
|
ircd::gpt::generate(const vector_view<u16> &out,
|
||||||
|
const vector_view<const u16> &in,
|
||||||
|
const opts &opts)
|
||||||
|
{
|
||||||
|
size_t ret(0);
|
||||||
|
bool halt(false);
|
||||||
|
uint errc[3] {0}, accc[3] {0};
|
||||||
|
for(uint i(0); !halt && i < out.size() && ret < opts.limit; ++i)
|
||||||
|
{
|
||||||
|
const size_t tokens
|
||||||
|
{
|
||||||
|
in.size() + i
|
||||||
|
};
|
||||||
|
|
||||||
|
const vector_view<f32> scratch
|
||||||
|
{
|
||||||
|
gpt::scratch, tokens * 768
|
||||||
|
};
|
||||||
|
|
||||||
|
for(uint j(0); j < in.size(); ++j)
|
||||||
|
{
|
||||||
|
const vector_view<f32> dst
|
||||||
|
{
|
||||||
|
data(scratch) + j * 768, 768
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto embedding
|
||||||
|
{
|
||||||
|
embed(dst, in[j], j)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
for(uint j(0); j < ret; ++j)
|
||||||
|
{
|
||||||
|
const vector_view<f32> dst
|
||||||
|
{
|
||||||
|
data(scratch) + (in.size() + j) * 768, 768
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto embedding
|
||||||
|
{
|
||||||
|
embed(dst, out[j], in.size() + j)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
transform(data(scratch), tokens, *device);
|
||||||
|
|
||||||
|
const vector_view<f32> last_embed
|
||||||
|
{
|
||||||
|
data(scratch) + ((tokens - 1) * 768), 768
|
||||||
|
};
|
||||||
|
|
||||||
|
tail(logit, data(last_embed), *device);
|
||||||
|
out[i] = argmax(logit, opts);
|
||||||
|
|
||||||
|
for(uint j(0); j < 3; ++j)
|
||||||
|
{
|
||||||
|
errc[j] = out[i] == opts.error_code[j][errc[j]]? errc[j] + 1: 0;
|
||||||
|
accc[j] = out[i] == opts.accept_code[j][accc[j]]? accc[j] + 1: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(uint j(0); j < 3; ++j)
|
||||||
|
{
|
||||||
|
halt |= errc[j] >= 3 || (errc[j] && opts.error_code[j][errc[j] + 1] == -1U);
|
||||||
|
halt |= accc[j] >= 3 || (accc[j] && opts.accept_code[j][accc[j] + 1] == -1U);
|
||||||
|
}
|
||||||
|
|
||||||
|
++ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
return vector_view<u16>
|
||||||
|
{
|
||||||
|
out, ret
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
ircd::vector_view<ircd::f32>
|
ircd::vector_view<ircd::f32>
|
||||||
ircd::gpt::embed(const vector_view<f32> &out,
|
ircd::gpt::embed(const vector_view<f32> &out,
|
||||||
const vector_view<const u16> &in)
|
const u16 token,
|
||||||
noexcept
|
const u16 position)
|
||||||
{
|
{
|
||||||
assert(device);
|
assert(device);
|
||||||
|
|
||||||
uint i(0);
|
const auto &wpe
|
||||||
for(; i < in.size(); ++i)
|
|
||||||
{
|
{
|
||||||
const auto &wpe
|
device->wpe[position]
|
||||||
{
|
};
|
||||||
device->wpe[i]
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto &wte
|
const auto &wte
|
||||||
{
|
{
|
||||||
device->wte[in[i]]
|
device->wte[token]
|
||||||
};
|
};
|
||||||
|
|
||||||
for(uint j(0); j < 768; ++j)
|
for(uint j(0); j < 768; ++j)
|
||||||
out[i * 768 + j] = wte[j] + wpe[j];
|
out[j] = wte[j] + wpe[j];
|
||||||
}
|
|
||||||
|
|
||||||
return vector_view<f32>
|
return vector_view<f32>
|
||||||
{
|
{
|
||||||
data(out), i * 768
|
data(out), 768
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
uint16_t
|
uint16_t
|
||||||
ircd::gpt::generate(const vector_view<const f32> &in)
|
ircd::gpt::argmax(const float *const __restrict__ logit,
|
||||||
noexcept
|
const opts &opts)
|
||||||
{
|
{
|
||||||
always_assert(in.size() % 768 == 0);
|
static const auto max
|
||||||
const auto toks
|
|
||||||
{
|
{
|
||||||
in.size() / 768
|
32U
|
||||||
};
|
};
|
||||||
|
|
||||||
const vector_view<f32> scratch
|
const auto top
|
||||||
{
|
{
|
||||||
gpt::scratch, in.size()
|
std::clamp(opts.top_k, 1U, max - 1)
|
||||||
};
|
};
|
||||||
|
|
||||||
for(uint i(0); i < in.size(); ++i)
|
u16 best[max] {0};
|
||||||
scratch[i] = in[i];
|
|
||||||
|
|
||||||
transform(data(scratch), toks, *device);
|
|
||||||
|
|
||||||
static float
|
|
||||||
buf alignas(64) [768];
|
|
||||||
|
|
||||||
for(uint i(0); i < 768; ++i)
|
|
||||||
buf[i] = scratch[(toks - 1) * 768 + i];
|
|
||||||
|
|
||||||
tail(logit, buf, *device);
|
|
||||||
return argmax(logit);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t
|
|
||||||
ircd::gpt::argmax(const float *const __restrict__ logit)
|
|
||||||
{
|
|
||||||
u16 ret(0);
|
|
||||||
for(uint j(0); j < vocab::tokens; ++j)
|
for(uint j(0); j < vocab::tokens; ++j)
|
||||||
if(logit[j] > logit[ret])
|
{
|
||||||
ret = j;
|
best[top] = j;
|
||||||
|
std::sort(begin(best), begin(best) + top + 1, [&logit]
|
||||||
|
(const auto &a, const auto &b)
|
||||||
|
{
|
||||||
|
return logit[a] > logit[b];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return ret;
|
const auto x
|
||||||
|
{
|
||||||
|
top > 1?
|
||||||
|
rand::integer(0, top - 1):
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
return best[x];
|
||||||
}
|
}
|
||||||
|
|
||||||
[[gnu::noinline]]
|
[[gnu::noinline]]
|
||||||
void
|
void
|
||||||
ircd::gpt::tail(float *const __restrict__ logit,
|
ircd::gpt::tail(float *const __restrict__ logit,
|
||||||
const float (&__restrict__ state)[768],
|
const float *const __restrict__ state,
|
||||||
const model::decoder &d)
|
const model::decoder &d)
|
||||||
{
|
{
|
||||||
static float
|
static float
|
||||||
|
@ -396,7 +492,7 @@ ircd::gpt::mask(float (&__restrict__ out)[12][1024][1024],
|
||||||
|
|
||||||
void
|
void
|
||||||
ircd::gpt::norm(float (&__restrict__ out)[768],
|
ircd::gpt::norm(float (&__restrict__ out)[768],
|
||||||
const float (&__restrict__ in)[768],
|
const float *const in,
|
||||||
const float (&__restrict__ bias)[768],
|
const float (&__restrict__ bias)[768],
|
||||||
const float (&__restrict__ weight)[768],
|
const float (&__restrict__ weight)[768],
|
||||||
const float epsilon)
|
const float epsilon)
|
||||||
|
@ -406,7 +502,7 @@ ircd::gpt::norm(float (&__restrict__ out)[768],
|
||||||
|
|
||||||
const float mean
|
const float mean
|
||||||
{
|
{
|
||||||
math::mean<float>(in)
|
math::mean<float>(vector_view<const float>{in, 768})
|
||||||
};
|
};
|
||||||
|
|
||||||
for(uint j(0); j < 768; ++j)
|
for(uint j(0); j < 768; ++j)
|
||||||
|
|
Loading…
Reference in a new issue