From 33a1ffd4bf54bcf530254fcb1ec248de2443ce05 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 9 Mar 2021 02:08:47 -0800 Subject: [PATCH] ircd::gpt: Add basic interface; add options, context. --- include/ircd/gpt/gpt.h | 55 +++++++++-- ircd/gpt.cc | 210 ++++++++++++++++++++++++++++++----------- 2 files changed, 198 insertions(+), 67 deletions(-) diff --git a/include/ircd/gpt/gpt.h b/include/ircd/gpt/gpt.h index abd01b39d..b01d9687f 100644 --- a/include/ircd/gpt/gpt.h +++ b/include/ircd/gpt/gpt.h @@ -17,21 +17,56 @@ namespace ircd::gpt { IRCD_EXCEPTION(ircd::error, error) - u16 - generate(const vector_view &) noexcept; - - vector_view - embed(const vector_view &, - const vector_view &) noexcept; + struct opts; + struct context; + extern const opts default_opts; extern log::log log; + + vector_view + generate(const vector_view &out, + const vector_view &in, + const opts & = default_opts); + + string_view + generate(const mutable_buffer &out, + const string_view &in, + const opts & = default_opts); } #include "vocab.h" #include "model.h" -namespace ircd::gpt +struct ircd::gpt::opts { - using vocab::detokenize; - using vocab::tokenize; -} + /// Specifies the nominal halting condition based on the sequence of + /// tokens. Generation will complete when this sequence is witnessed. Set + /// tokens to -1 starting from the back to not match that token. Setting + /// all tokens to -1 will ignore this condition. + uint accept_code[3][3] + { + { 13, 198, -1U, }, + { 198, 198, -1U, }, + { -1U, -1U, -1U, }, + }; + + /// Specifies the exceptional halting condition based on the sequence of + /// tokens. By default, the three zeros represent three outputs of '!' + /// which is probably an error code; note that a true "!!!" is represented + /// by token number 10185. Set tokens to -1 starting from the back to + /// not match that token; generated output after errors is usually garbage. + uint error_code[3][3] + { + { 0, 0, 0, }, + { -1U, 0, 0, }, + { -1U, 0, 0, }, + }; + + /// Limit number of output tokens. Default of -1 is unlimited; the number + /// of tokens generated will be limited by other factors. + uint limit {-1U}; + + /// Flip random coins over the top k logits each round. Setting to 1 + /// deterministically selects the top logit. + uint top_k {2}; +}; diff --git a/ircd/gpt.cc b/ircd/gpt.cc index f9f2fc018..18a35dc5d 100644 --- a/ircd/gpt.cc +++ b/ircd/gpt.cc @@ -8,17 +8,11 @@ // copyright notice and this permission notice is present in all copies. The // full license for this software is available in the LICENSE file. -decltype(ircd::gpt::log) -ircd::gpt::log -{ - "gpt" -}; - namespace ircd::gpt { static void gelu(float &, const float &); static void gelu(float (&)[3072], const float (&)[3072]); - static void norm(float (&)[768], const float (&)[768], const float (&)[768], const float (&)[768], const float); + static void norm(float (&)[768], const float *, const float (&)[768], const float (&)[768], const float); static void fmma(float (&)[768], const float (&)[3072], const float (&)[768], const float (&)[3072][768]); static void fmma(float (&)[3072], const float (&)[768], const float (&)[3072], const float (&)[768][3072]); static void fmma(float (&)[2304], const float (&)[768], const float (&)[2304], const float (&)[768][2304]); @@ -32,8 +26,10 @@ namespace ircd::gpt static void transform(float *, const size_t, const model::decoder &); static void logitsmax(float *, const float *); static void logits(float *, const float (&)[768], const model::decoder &); - static void tail(float *, const float (&)[768], const model::decoder &); - static u16 argmax(const float *); + static void tail(float *, const float *, const model::decoder &); + static u16 argmax(const float *, const opts &); + + static vector_view embed(const vector_view &out, const u16 token, const u16 position); std::unique_ptr device { @@ -45,6 +41,15 @@ namespace ircd::gpt scratch alignas(64) [1024 * 768]; } +decltype(ircd::gpt::default_opts) +ircd::gpt::default_opts; + +decltype(ircd::gpt::log) +ircd::gpt::log +{ + "gpt" +}; + namespace ircd::gpt::model { constexpr float embed_pdrop @@ -78,81 +83,172 @@ namespace ircd::gpt::model }; } +ircd::string_view +ircd::gpt::generate(const mutable_buffer &out, + const string_view &in, + const opts &opts) +{ + u16 buf[2][256]; + const auto input_tokens + { + vocab::tokenize(buf[0], in) + }; + + const auto output_tokens + { + generate(buf[1], input_tokens, opts) + }; + + const auto output + { + vocab::detokenize(out, output_tokens) + }; + + return output; +} + +ircd::vector_view +ircd::gpt::generate(const vector_view &out, + const vector_view &in, + const opts &opts) +{ + size_t ret(0); + bool halt(false); + uint errc[3] {0}, accc[3] {0}; + for(uint i(0); !halt && i < out.size() && ret < opts.limit; ++i) + { + const size_t tokens + { + in.size() + i + }; + + const vector_view scratch + { + gpt::scratch, tokens * 768 + }; + + for(uint j(0); j < in.size(); ++j) + { + const vector_view dst + { + data(scratch) + j * 768, 768 + }; + + const auto embedding + { + embed(dst, in[j], j) + }; + } + + for(uint j(0); j < ret; ++j) + { + const vector_view dst + { + data(scratch) + (in.size() + j) * 768, 768 + }; + + const auto embedding + { + embed(dst, out[j], in.size() + j) + }; + } + + transform(data(scratch), tokens, *device); + + const vector_view last_embed + { + data(scratch) + ((tokens - 1) * 768), 768 + }; + + tail(logit, data(last_embed), *device); + out[i] = argmax(logit, opts); + + for(uint j(0); j < 3; ++j) + { + errc[j] = out[i] == opts.error_code[j][errc[j]]? errc[j] + 1: 0; + accc[j] = out[i] == opts.accept_code[j][accc[j]]? accc[j] + 1: 0; + } + + for(uint j(0); j < 3; ++j) + { + halt |= errc[j] >= 3 || (errc[j] && opts.error_code[j][errc[j] + 1] == -1U); + halt |= accc[j] >= 3 || (accc[j] && opts.accept_code[j][accc[j] + 1] == -1U); + } + + ++ret; + } + + return vector_view + { + out, ret + }; +} + ircd::vector_view ircd::gpt::embed(const vector_view &out, - const vector_view &in) -noexcept + const u16 token, + const u16 position) { assert(device); - uint i(0); - for(; i < in.size(); ++i) + const auto &wpe { - const auto &wpe - { - device->wpe[i] - }; + device->wpe[position] + }; - const auto &wte - { - device->wte[in[i]] - }; + const auto &wte + { + device->wte[token] + }; - for(uint j(0); j < 768; ++j) - out[i * 768 + j] = wte[j] + wpe[j]; - } + for(uint j(0); j < 768; ++j) + out[j] = wte[j] + wpe[j]; return vector_view { - data(out), i * 768 + data(out), 768 }; } uint16_t -ircd::gpt::generate(const vector_view &in) -noexcept +ircd::gpt::argmax(const float *const __restrict__ logit, + const opts &opts) { - always_assert(in.size() % 768 == 0); - const auto toks + static const auto max { - in.size() / 768 + 32U }; - const vector_view scratch + const auto top { - gpt::scratch, in.size() + std::clamp(opts.top_k, 1U, max - 1) }; - for(uint i(0); i < in.size(); ++i) - scratch[i] = in[i]; - - transform(data(scratch), toks, *device); - - static float - buf alignas(64) [768]; - - for(uint i(0); i < 768; ++i) - buf[i] = scratch[(toks - 1) * 768 + i]; - - tail(logit, buf, *device); - return argmax(logit); -} - -uint16_t -ircd::gpt::argmax(const float *const __restrict__ logit) -{ - u16 ret(0); + u16 best[max] {0}; for(uint j(0); j < vocab::tokens; ++j) - if(logit[j] > logit[ret]) - ret = j; + { + best[top] = j; + std::sort(begin(best), begin(best) + top + 1, [&logit] + (const auto &a, const auto &b) + { + return logit[a] > logit[b]; + }); + } - return ret; + const auto x + { + top > 1? + rand::integer(0, top - 1): + 0 + }; + + return best[x]; } [[gnu::noinline]] void ircd::gpt::tail(float *const __restrict__ logit, - const float (&__restrict__ state)[768], + const float *const __restrict__ state, const model::decoder &d) { static float @@ -396,7 +492,7 @@ ircd::gpt::mask(float (&__restrict__ out)[12][1024][1024], void ircd::gpt::norm(float (&__restrict__ out)[768], - const float (&__restrict__ in)[768], + const float *const in, const float (&__restrict__ bias)[768], const float (&__restrict__ weight)[768], const float epsilon) @@ -406,7 +502,7 @@ ircd::gpt::norm(float (&__restrict__ out)[768], const float mean { - math::mean(in) + math::mean(vector_view{in, 768}) }; for(uint j(0); j < 768; ++j)