From 53c4260a2192beed58cacf8afd66968f88f679c1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 5 Mar 2021 15:33:05 -0800 Subject: [PATCH] ircd::gpt: Add Basic Latin (lower) and C0 replacement LUT; various. --- include/ircd/gpt/vocab.h | 10 +- ircd/gpt_vocab.cc | 426 ++++++++++++++++++++++++--------------- 2 files changed, 268 insertions(+), 168 deletions(-) diff --git a/include/ircd/gpt/vocab.h b/include/ircd/gpt/vocab.h index dee402521..4d6d700d3 100644 --- a/include/ircd/gpt/vocab.h +++ b/include/ircd/gpt/vocab.h @@ -15,6 +15,8 @@ /// namespace ircd::gpt::vocab { + IRCD_EXCEPTION(gpt::error, error) + // Actual number of tokens and merges stored in following lists. extern size_t tokens, @@ -32,12 +34,8 @@ namespace ircd::gpt::vocab merges_path; // Tokenize UTF-8 input string of any length into proper token values, - vector_view - tokenize(const vector_view &out, - const string_view &in) noexcept; + vector_view tokenize(const vector_view &out, const string_view &in); // Decode token values to build output text string. 
- string_view - detokenize(const mutable_buffer &out, - const vector_view &in) noexcept; + string_view detokenize(const mutable_buffer &out, const vector_view &in); } diff --git a/ircd/gpt_vocab.cc b/ircd/gpt_vocab.cc index d479b0c3b..e3955e0f0 100644 --- a/ircd/gpt_vocab.cc +++ b/ircd/gpt_vocab.cc @@ -10,24 +10,62 @@ namespace ircd::gpt::vocab { - static u16 find_token(const u8x16) noexcept; - static uint find_tokens(u16x16 &, const uint, const u8x16 (&)[16], const uint) noexcept; - static u16 find_merge(const u8x16, const u8x16) noexcept; - - static u16 bpe_score(u16 (&)[16], const u8x16 (&)[16][2], const uint) noexcept; - static uint bpe_merge(u8x16 (&)[16][2], u16 (&)[16], const uint, const u16) noexcept; - static uint bpe_postpare(u8x16 (&)[16], const u8x16 (&)[16][2], const uint) noexcept; - static uint bpe_prepare(u8x16 (&)[16][2], const u8x16) noexcept; - static uint bpe_tokenize(u8x16 (&)[16], const u8x16) noexcept; - - static u64x2 pre_tokenize_split(u8x16 (&)[16], u32x16, u32x16, u32x16) noexcept; - static u64x2 pre_tokenize(u8x16 (&)[16], const u8x16, const u8x16) noexcept; - + static u16 find_token(const u8x16); + static u16 find_merge(const u8x16, const u8x16); + static u16 bpe_score(u16 (&)[16], const u8x16 (&)[16][2], const uint); + static uint bpe_merge(u8x16 (&)[16][2], u16 (&)[16], const uint, const u16); + static uint bpe_postpare(u8x16 (&)[16], const u8x16 (&)[16][2], const uint); + static uint bpe_prepare(u8x16 (&)[16][2], const u8x16); + static uint bpe_tokenize(u8x16 (&)[16], const u8x16); + static u64x2 pre_tokenize_split(u8x16 (&)[16], u32x16, u32x16, u32x16); + static u64x2 pre_tokenize(u8x16 (&)[16], const u8x16, const u8x16); + static u64x2 unk_tokenize(u16x16 &, const u8x16, u64); static u64x2 tokenize_block(u16x16 &, const u8x16, const u8x16) noexcept; - static void init_tokens(), init_merges(); + + [[gnu::visibility("internal")]] + extern const char32_t charset[256]; } +/// Remapping of single byte characters (Control (C0) and Basic 
Latin (ASCII)). +decltype(ircd::gpt::vocab::charset) +ircd::gpt::vocab::charset +alignas(64) +{ + U'Ā', U'ā', U'Ă', U'ă', U'Ą', U'ą', U'Ć', U'ć', // [0x07] + U'Ĉ', U'ĉ', U'Ċ', U'ċ', U'Č', U'č', U'Ď', U'ď', // [0x0F] + U'Đ', U'đ', U'Ē', U'ē', U'Ĕ', U'ĕ', U'Ė', U'ė', // [0x17] + U'Ę', U'ę', U'Ě', U'ě', U'Ĝ', U'ĝ', U'Ğ', U'ğ', // [0x1F] + U'Ġ', U'!', U'"', U'#', U'$', U'%', U'&', U'\'', // [0x27] + U'(', U')', U'*', U'+', U',', U'-', U'.', U'/', // [0x2F] + U'0', U'1', U'2', U'3', U'4', U'5', U'6', U'7', // [0x37] + U'8', U'9', U':', U';', U'<', U'=', U'>', U'?', // [0x3F] + U'@', U'A', U'B', U'C', U'D', U'E', U'F', U'G', // [0x47] + U'H', U'I', U'J', U'K', U'L', U'M', U'N', U'O', // [0x4F] + U'P', U'Q', U'R', U'S', U'T', U'U', U'V', U'W', // [0x57] + U'X', U'Y', U'Z', U'[', U'\\', U']', U'^', U'_', // [0x5F] + U'`', U'a', U'b', U'c', U'd', U'e', U'f', U'g', // [0x67] + U'h', U'i', U'j', U'k', U'l', U'm', U'n', U'o', // [0x6F] + U'p', U'q', U'r', U's', U't', U'u', U'v', U'w', // [0x77] + U'x', U'y', U'z', U'{', U'|', U'}', U'~', U'ġ', // [0x7F] + U'Ģ', U'ģ', U'Ĥ', U'ĥ', U'Ħ', U'ħ', U'Ĩ', U'ĩ', // [0x87] + U'Ī', U'ī', U'Ĭ', U'ĭ', U'Į', U'į', U'İ', U'ı', // [0x8F] + U'Ĳ', U'ĳ', U'Ĵ', U'ĵ', U'Ķ', U'ķ', U'ĸ', U'Ĺ', // [0x97] + U'ĺ', U'Ļ', U'ļ', U'Ľ', U'ľ', U'Ŀ', U'ŀ', U'Ł', // [0x9F] + U'ł', U'¡', U'¢', U'£', U'¤', U'¥', U'¦', U'§', // [0xA7] + U'¨', U'©', U'ª', U'«', U'¬', U'Ń', U'®', U'¯', // [0xAF] + U'°', U'±', U'²', U'³', U'´', U'µ', U'¶', U'·', // [0xB7] + U'¸', U'¹', U'º', U'»', U'¼', U'½', U'¾', U'¿', // [0xBF] + U'À', U'Á', U'Â', U'Ã', U'Ä', U'Å', U'Æ', U'Ç', // [0xC7] + U'È', U'É', U'Ê', U'Ë', U'Ì', U'Í', U'Î', U'Ï', // [0xCF] + U'Ð', U'Ñ', U'Ò', U'Ó', U'Ô', U'Õ', U'Ö', U'×', // [0xD7] + U'Ø', U'Ù', U'Ú', U'Û', U'Ü', U'Ý', U'Þ', U'ß', // [0xDF] + U'à', U'á', U'â', U'ã', U'ä', U'å', U'æ', U'ç', // [0xE7] + U'è', U'é', U'ê', U'ë', U'ì', U'í', U'î', U'ï', // [0xEF] + U'ð', U'ñ', U'ò', U'ó', U'ô', U'õ', U'ö', U'÷', // [0xF7] + U'ø', U'ù', U'ú', U'û', U'ü', U'ý', 
U'þ', U'ÿ', // [0xFF] +}; + decltype(ircd::gpt::vocab::tokens) ircd::gpt::vocab::tokens; @@ -124,15 +162,34 @@ ircd::gpt::vocab::init_merges() ircd::string_view ircd::gpt::vocab::detokenize(const mutable_buffer &out, const vector_view &in) -noexcept { - mutable_buffer buf(out); - for(const u16 &token : in) - consume(buf, copy(buf, const_buffer(vocab::token[token], size(string_view(vocab::token[token]))))); + size_t off(0); + for(const u16 &id : in) + { + const auto &token + { + vocab::token[id] + }; + const string_view text + { + token, strnlen(token, 16) + }; + + string_view dest + { + data(out + off), copy(out + off, text) + }; + + dest = replace(out + off, dest, "Ġ"_sv, " "_sv); + dest = replace(out + off, dest, "Ċ"_sv, "\n"_sv); + off += size(dest); + } + + assert(off <= size(out)); return string_view { - data(out), data(buf) + data(out), off }; } @@ -143,7 +200,6 @@ noexcept ircd::vector_view ircd::gpt::vocab::tokenize(const vector_view &out, const string_view &in) -noexcept { using input_t = u8x16; using block_t = u16x16; @@ -165,6 +221,7 @@ noexcept }; assert(consumed[0] <= out.size()); + assert(consumed[0] <= consumed[1]); return vector_view ( out.data(), consumed[0] @@ -183,27 +240,39 @@ noexcept pre_tokenize(pre_token, in, in_mask) }; - uint tokens(0); - for(uint i(0); i < pre_tokens; ++i) + u64x2 ret { + 0, consumed + }; + + for(uint i(0); i < pre_tokens && ret[0] < 16; ++i) + { + // one token in hand is worth two in the bpe + if(likely((token[ret[0]] = find_token(pre_token[i])) != u16(-1))) + { + ++ret[0]; + continue; + } + u8x16 str[16]; const uint strs { bpe_tokenize(str, pre_token[i]) }; - const uint addl_tokens + for(uint j(0); j < strs && ret[0] < 16; ++j) { - find_tokens(token, tokens, str, strs) - }; + if(likely((token[ret[0]] = find_token(str[j])) != u16(-1))) + { + ++ret[0]; + continue; + } - tokens += addl_tokens; + ret += unk_tokenize(token, str[j], ret[0]); + } } - return u64x2 - { - tokens, consumed - }; + return ret; } // @@ -221,107 
+290,111 @@ ircd::u64x2 ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16], const u8x16 in, const u8x16 in_mask) -noexcept { const u8x16 is_ascii_ctrl - { + ( in < 0x20 - }; + ); const u8x16 is_ascii_space - { + ( in == ' ' - }; + ); const u8x16 is_ascii_number - { + ( in >= '0' && in <= '9' - }; + ); const u8x16 is_ascii_letter - { + ( (in >= 'a' && in <= 'z') || (in >= 'A' && in <= 'Z') - }; + ); const u8x16 ascii_identified - { - is_ascii_space | is_ascii_number | is_ascii_letter - }; + ( + is_ascii_ctrl | is_ascii_space | is_ascii_number | is_ascii_letter + ); const u8x16 maybe_notascii - { + ( ~ascii_identified & in_mask - }; + ); const u32x16 ch - { + ( utf8::decode(in) - }; + ); const u32x16 uc_cat - { + ( icu::category(ch & (lane_cast(maybe_notascii) != 0)) - }; + ); const u32x16 is_L - {0 + (0 | ((uc_cat & 0x0000003eU) != 0) | (lane_cast(is_ascii_letter) != 0) - }; + ); const u32x16 is_N - {0 + (0 | ((uc_cat & 0x00000e00U) != 0) | (lane_cast(is_ascii_number) != 0) - }; + ); const u32x16 is_Z - {0 + (0 | ((uc_cat & 0x00007000U) != 0) | (lane_cast(is_ascii_space) != 0) - }; + ); + + const u32x16 is_C0 + (0 + | (lane_cast(is_ascii_ctrl) != 0) + ); const u32x16 is_trail - {0 + (0 | (is_L & shl<32>(is_L)) | (is_N & shl<32>(is_N)) | (is_Z & shl<32>(is_Z)) - }; + ); const u32x16 fat_mask - { + ( lane_cast(in_mask) != 0 - }; + ); const u32x16 is_head - { - ~is_trail & fat_mask - }; + ( + (~is_trail | is_C0) & fat_mask + ); // mask if token is preceded by a space const u32x16 leading_space - { + ( is_head & shl<32>(is_Z) - }; + ); // zero or one preceding space becomes prefixed to the next token const u32x16 tok_head - {0 + (0 | (is_head & ~leading_space) | shr<32>(leading_space) - }; + ); const u32x16 tok_trail - { + ( ~tok_head - }; + ); const u32x16 tok_mask - { + ( tok_trail - }; + ); - const u64x2 ret + const auto ret { pre_tokenize_split(token, ch, fat_mask, tok_mask) }; @@ -337,69 +410,139 @@ ircd::gpt::vocab::pre_tokenize_split(u8x16 (&token)[16], u32x16 ch, 
u32x16 ch_mask, u32x16 tok_mask) -noexcept { - const u32x16 lane0_mask - { - -1U - }; + // Replace single-byte codepoints from the LUT. + u32x16 rch; + for(uint i(0); i < 16; ++i) + rch[i] = ch[i] > 0xFF? + ch[i]: charset[ch[i]]; - u64x2 ret + u64x2 ret {0, 0}; + for(uint i(0); ret[0] >= i && ret[1] < 16; ++i) { - 0, 0 - }; + static const u32x16 lane0_mask + { + -1U + }; - for(uint i(0); ret[0] == i && ret[1] < 16; ++i) - { // Create a mask from all non-leading characters of input tokens with // a mask of just the leading character of the first token. To be sure // extra characters are not included we rinse it with the ch_mask. const u32x16 cover_mask - { + ( (lane0_mask | tok_mask) & ch_mask + ); + + // Get the number of codepoints of the first token from the cover. + const auto cp_num + { + std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL) }; - // Get the length of the first token from the cover. - const u64 len - { - std::min(simd::lzcnt(~cover_mask) / 32, 16U) - }; + // Input codepoint lengths + const u32x16 cp_len + ( + utf8::length(ch & cover_mask) + ); + + // Output codepoint lengths + const u32x16 rcp_len + ( + utf8::length(rch & cover_mask) + ); + + // Generate utf-8 codepoints + const u8x64 rch8 + ( + utf8::encode(rch & cover_mask) + ); + + u32x16 idx; + uint off(0); // result bytes of utf-8 + for(uint j(0); j < cp_num; off += rcp_len[j++]) + idx[j] = off; + + uint len(0); // input bytes of utf-8 + for(uint j(0); j < cp_num; ++j) + len += cp_len[j]; // When the first token is too large, we truncate that token here and // return, effectively splitting the token into multiple. If the token // after the first is too large (input potentially spans into the next // block), we kick it to the next iteration entirely. 
- const bool skip + assert(ret[1] <= 16); + const auto skip { - len >= 16 - ret[1] && ret[0] + boolmask(ret[1] + off >= 16 && i > 0) }; - // Generate utf-8 codepoints - const u32x16 ch8 + // We have to return the proper number of bytes for what was truncated + // from the input, but the truncation is determined after a transform + // which may have a different size; this has to be offset back now. + if(!skip && ret[1] + off > 16) { - utf8::encode(ch & cover_mask) - }; + assert(off >= len); + len -= (off - len); + } // Pack the utf-8 codepoints into the result token token[i] = {0}; - for(uint j(0); j < 16; ++j) - token[i][j] = ch8[j]; + for(uint j(0); j < cp_num; ++j) + for(uint k(0); k < rcp_len[j] && idx[j] + k < 16; ++k) + token[i][idx[j] + k] = rch8[j * 4 + k]; // Shift the token off the input to consume the next. - for(uint j(0); j < len; ++j) + for(uint j(0); j < cp_num; ++j) { ch = shr<32>(ch); + rch = shr<32>(rch); ch_mask = shr<32>(ch_mask); tok_mask = shr<32>(tok_mask); } - ret[0] += bool(len) && !skip; - ret[1] += len & boolmask(!skip); + ret[0] += !skip && len; + ret[1] += ~skip & len; } return ret; } +[[gnu::noinline]] +ircd::u64x2 +ircd::gpt::vocab::unk_tokenize(u16x16 &token, + const u8x16 str, + const u64 num) +{ + u64 tokens(0), consumed(0); + const auto len(simd::strlen(str)); + while(consumed < len && num + tokens < 16) + for(uint i(0); i < len; ++i) + { + u8x16 s(str); + for(uint j(0); j < consumed; ++j) + s = shr<8>(s); + + for(uint j(len - i); j < 16; ++j) + s[j] = 0; + + if((token[num + tokens] = find_token(s)) != u16(-1)) + { + consumed += len - i; + ++tokens; + break; + } + } + + assert(len >= consumed); + assert(num + tokens <= 16); + const auto overflow{len - consumed}; + assert(overflow == 0 || num + tokens == 16); + return u64x2 + { + tokens, 0 + }; +} + // // byte-pair encoding // @@ -408,7 +551,6 @@ noexcept uint ircd::gpt::vocab::bpe_tokenize(u8x16 (&str)[16], const u8x16 pre_token) -noexcept { if(simd::strlen(pre_token) < 2) { @@ 
-430,9 +572,6 @@ noexcept bpe_score(score, pair, pairs) }; - if(best_score >= u16(-1)) - break; - const auto merges { bpe_merge(pair, score, pairs, best_score) @@ -454,48 +593,42 @@ noexcept uint ircd::gpt::vocab::bpe_prepare(u8x16 (&out)[16][2], const u8x16 in) -noexcept { - uint di, si; - for(di = 0, si = 0; si < 16 && di < 16; si += 2, di += 2) + const auto len { - out[di][0] = {0}; - out[di][1] = {0}; - if(!in[si] || !in[si + 1]) - break; + simd::strlen(in) + }; - //TODO: XXX - if(!si && in[si] == ' ') + const u32x16 cplen + ( + utf8::length(utf8::decode(in)) + ); + + u32x16 idx; + for(uint i(0), off(0); i < 16; off += cplen[i++]) + idx[i] = off; + + uint ret(0); + for(uint phase(0); phase < 2; ++phase) + for(uint i(phase); i < 16; i += 2, ++ret) { - out[di][0][0] = 0xc4; - out[di][0][1] = 0xa0; - out[di][1][0] = in[si + 1]; - continue; + if(idx[i] >= 16 || !in[idx[i]]) + break; + + out[i][0] = {0}; + out[i][1] = {0}; + for(uint k(0); k < 2; ++k) + for(uint j(0); j < cplen[i + k] && idx[i + k] + j < 16; ++j) + out[i][k][j] = in[idx[i + k] + j]; } - out[di][0][0] = in[si]; - out[di][1][0] = in[si + 1]; - } - - for(di = 1, si = 1; si < 16 && di < 16; si += 2, di += 2) - { - out[di][0] = {0}; - out[di][1] = {0}; - if(!in[si] || !in[si + 1]) - break; - - out[di][0][0] = in[si]; - out[di][1][0] = in[si + 1]; - } - - return di; + return ret; } uint ircd::gpt::vocab::bpe_postpare(u8x16 (&out)[16], const u8x16 (&in)[16][2], const uint num) -noexcept { uint ret(0); for(uint j(0); j < num; ++j) @@ -514,7 +647,6 @@ ircd::gpt::vocab::bpe_merge(u8x16 (&pair)[16][2], u16 (&score)[16], const uint num, const u16 best_score) -noexcept { uint ret(0); @@ -552,7 +684,6 @@ ircd::u16 ircd::gpt::vocab::bpe_score(u16 (&score)[16], const u8x16 (&pair)[16][2], const uint num) -noexcept { uint best(-1U), is_min; for(uint i(0); i < num; i++) @@ -575,36 +706,8 @@ noexcept // queries // -uint -ircd::gpt::vocab::find_tokens(u16x16 &token, - const uint tokens, - const u8x16 (&str)[16], - 
const uint strs) -noexcept -{ - uint ret(0); - for(; tokens + ret < 16 && ret < strs; ++ret) - { - const auto val - { - find_token(str[ret]) - }; - - const bool found - { - val != u16(-1) - }; - - assert(found); - token[tokens + ret] = val; - } - - return ret; -} - ircd::u16 ircd::gpt::vocab::find_token(const u8x16 string) -noexcept { const auto *const __restrict__ token { @@ -621,7 +724,6 @@ noexcept ircd::u16 ircd::gpt::vocab::find_merge(const u8x16 a, const u8x16 b) -noexcept { const auto &__restrict__ merge {