mirror of https://github.com/matrix-construct/construct

ircd::gpt: Add Basic Latin (lower) and C0 replacement LUT; various.

Jason Volk 2021-03-05 15:33:05 -08:00
parent c014fa2bbe
commit 53c4260a21
2 changed files with 268 additions and 168 deletions
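Background on the headline change: GPT-2-style byte-level BPE vocabularies do not store raw bytes; every byte is remapped to a printable Unicode codepoint so tokens are always valid, visible text. The charset[] table added below appears to encode exactly that convention. A standalone sketch of the mapping (independent of ircd; names here are illustrative, not from this commit):

    #include <array>
    #include <cstdint>

    // Printable bytes map to themselves; the rest (C0 controls, space, DEL,
    // 0x80-0xA0, 0xAD) are displaced to codepoints starting at U+0100, so
    // e.g. 0x20 (space) becomes U+0120 'Ġ' and 0x0A (LF) becomes U+010A 'Ċ'.
    std::array<char32_t, 256> bytes_to_unicode()
    {
        std::array<char32_t, 256> map{};
        char32_t next{256};
        for(uint32_t b(0); b < 256; ++b)
        {
            const bool printable
            (
                (b >= '!' && b <= '~') ||   // Basic Latin w/o space & controls
                (b >= 0xA1 && b <= 0xAC) || // Latin-1, skipping 0xAD (SHY)
                (b >= 0xAE && b <= 0xFF)
            );
            map[b] = printable? char32_t(b): next++;
        }
        return map;
    }

Running this and comparing against the table below should agree entry for entry.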


@@ -15,6 +15,8 @@
 ///
 namespace ircd::gpt::vocab
 {
+    IRCD_EXCEPTION(gpt::error, error)
+
     // Actual number of tokens and merges stored in following lists.
     extern size_t
     tokens,
@@ -32,12 +34,8 @@ namespace ircd::gpt::vocab
     merges_path;
 
     // Tokenize UTF-8 input string of any length into proper token values,
-    vector_view<u16>
-    tokenize(const vector_view<u16> &out,
-             const string_view &in) noexcept;
+    vector_view<u16> tokenize(const vector_view<u16> &out, const string_view &in);
 
     // Decode token values to build output text string.
-    string_view
-    detokenize(const mutable_buffer &out,
-               const vector_view<const u16> &in) noexcept;
+    string_view detokenize(const mutable_buffer &out, const vector_view<const u16> &in);
 }


@@ -10,24 +10,62 @@
 namespace ircd::gpt::vocab
 {
-    static u16 find_token(const u8x16) noexcept;
-    static uint find_tokens(u16x16 &, const uint, const u8x16 (&)[16], const uint) noexcept;
-    static u16 find_merge(const u8x16, const u8x16) noexcept;
-    static u16 bpe_score(u16 (&)[16], const u8x16 (&)[16][2], const uint) noexcept;
-    static uint bpe_merge(u8x16 (&)[16][2], u16 (&)[16], const uint, const u16) noexcept;
-    static uint bpe_postpare(u8x16 (&)[16], const u8x16 (&)[16][2], const uint) noexcept;
-    static uint bpe_prepare(u8x16 (&)[16][2], const u8x16) noexcept;
-    static uint bpe_tokenize(u8x16 (&)[16], const u8x16) noexcept;
-    static u64x2 pre_tokenize_split(u8x16 (&)[16], u32x16, u32x16, u32x16) noexcept;
-    static u64x2 pre_tokenize(u8x16 (&)[16], const u8x16, const u8x16) noexcept;
+    static u16 find_token(const u8x16);
+    static u16 find_merge(const u8x16, const u8x16);
+    static u16 bpe_score(u16 (&)[16], const u8x16 (&)[16][2], const uint);
+    static uint bpe_merge(u8x16 (&)[16][2], u16 (&)[16], const uint, const u16);
+    static uint bpe_postpare(u8x16 (&)[16], const u8x16 (&)[16][2], const uint);
+    static uint bpe_prepare(u8x16 (&)[16][2], const u8x16);
+    static uint bpe_tokenize(u8x16 (&)[16], const u8x16);
+    static u64x2 pre_tokenize_split(u8x16 (&)[16], u32x16, u32x16, u32x16);
+    static u64x2 pre_tokenize(u8x16 (&)[16], const u8x16, const u8x16);
+    static u64x2 unk_tokenize(u16x16 &, const u8x16, u64);
     static u64x2 tokenize_block(u16x16 &, const u8x16, const u8x16) noexcept;
     static void init_tokens(), init_merges();
+
+    [[gnu::visibility("internal")]]
+    extern const char32_t charset[256];
 }
+
+/// Remapping of single byte characters (Control (C0) and Basic Latin (ASCII)).
+decltype(ircd::gpt::vocab::charset)
+ircd::gpt::vocab::charset
+alignas(64)
+{
+    U'Ā', U'ā', U'Ă', U'ă', U'Ą', U'ą', U'Ć', U'ć', // [0x07]
+    U'Ĉ', U'ĉ', U'Ċ', U'ċ', U'Č', U'č', U'Ď', U'ď', // [0x0F]
+    U'Đ', U'đ', U'Ē', U'ē', U'Ĕ', U'ĕ', U'Ė', U'ė', // [0x17]
+    U'Ę', U'ę', U'Ě', U'ě', U'Ĝ', U'ĝ', U'Ğ', U'ğ', // [0x1F]
+    U'Ġ', U'!', U'"', U'#', U'$', U'%', U'&', U'\'', // [0x27]
+    U'(', U')', U'*', U'+', U',', U'-', U'.', U'/', // [0x2F]
+    U'0', U'1', U'2', U'3', U'4', U'5', U'6', U'7', // [0x37]
+    U'8', U'9', U':', U';', U'<', U'=', U'>', U'?', // [0x3F]
+    U'@', U'A', U'B', U'C', U'D', U'E', U'F', U'G', // [0x47]
+    U'H', U'I', U'J', U'K', U'L', U'M', U'N', U'O', // [0x4F]
+    U'P', U'Q', U'R', U'S', U'T', U'U', U'V', U'W', // [0x57]
+    U'X', U'Y', U'Z', U'[', U'\\', U']', U'^', U'_', // [0x5F]
+    U'`', U'a', U'b', U'c', U'd', U'e', U'f', U'g', // [0x67]
+    U'h', U'i', U'j', U'k', U'l', U'm', U'n', U'o', // [0x6F]
+    U'p', U'q', U'r', U's', U't', U'u', U'v', U'w', // [0x77]
+    U'x', U'y', U'z', U'{', U'|', U'}', U'~', U'ġ', // [0x7F]
+    U'Ģ', U'ģ', U'Ĥ', U'ĥ', U'Ħ', U'ħ', U'Ĩ', U'ĩ', // [0x87]
+    U'Ī', U'ī', U'Ĭ', U'ĭ', U'Į', U'į', U'İ', U'ı', // [0x8F]
+    U'Ĳ', U'ĳ', U'Ĵ', U'ĵ', U'Ķ', U'ķ', U'ĸ', U'Ĺ', // [0x97]
+    U'ĺ', U'Ļ', U'ļ', U'Ľ', U'ľ', U'Ŀ', U'ŀ', U'Ł', // [0x9F]
+    U'ł', U'¡', U'¢', U'£', U'¤', U'¥', U'¦', U'§', // [0xA7]
+    U'¨', U'©', U'ª', U'«', U'¬', U'Ń', U'®', U'¯', // [0xAF]
+    U'°', U'±', U'²', U'³', U'´', U'µ', U'¶', U'·', // [0xB7]
+    U'¸', U'¹', U'º', U'»', U'¼', U'½', U'¾', U'¿', // [0xBF]
+    U'À', U'Á', U'Â', U'Ã', U'Ä', U'Å', U'Æ', U'Ç', // [0xC7]
+    U'È', U'É', U'Ê', U'Ë', U'Ì', U'Í', U'Î', U'Ï', // [0xCF]
+    U'Ð', U'Ñ', U'Ò', U'Ó', U'Ô', U'Õ', U'Ö', U'×', // [0xD7]
+    U'Ø', U'Ù', U'Ú', U'Û', U'Ü', U'Ý', U'Þ', U'ß', // [0xDF]
+    U'à', U'á', U'â', U'ã', U'ä', U'å', U'æ', U'ç', // [0xE7]
+    U'è', U'é', U'ê', U'ë', U'ì', U'í', U'î', U'ï', // [0xEF]
+    U'ð', U'ñ', U'ò', U'ó', U'ô', U'õ', U'ö', U'÷', // [0xF7]
+    U'ø', U'ù', U'ú', U'û', U'ü', U'ý', U'þ', U'ÿ', // [0xFF]
+};
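How the table is consumed: before any vocabulary lookup, each input byte is lifted through charset[] so the token text being matched is the remapped, always-printable form. A minimal scalar illustration (standalone; not the commit's SIMD path, which does this inside pre_tokenize_split() below):

    #include <string>
    #include <string_view>

    // One codepoint out per byte in; `charset` is the 256-entry LUT above.
    std::u32string remap(std::string_view in, const char32_t (&charset)[256])
    {
        std::u32string out;
        out.reserve(in.size());
        for(const unsigned char b : in)
            out.push_back(charset[b]);
        return out;
    }

For example, remap(" hi", charset) yields U"Ġhi", matching how GPT-2 vocabulary entries spell a leading space.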
 
 decltype(ircd::gpt::vocab::tokens)
 ircd::gpt::vocab::tokens;
 
@@ -124,15 +162,34 @@ ircd::gpt::vocab::init_merges()
 ircd::string_view
 ircd::gpt::vocab::detokenize(const mutable_buffer &out,
                              const vector_view<const u16> &in)
-noexcept
 {
-    mutable_buffer buf(out);
-    for(const u16 &token : in)
-        consume(buf, copy(buf, const_buffer(vocab::token[token], size(string_view(vocab::token[token])))));
+    size_t off(0);
+    for(const u16 &id : in)
+    {
+        const auto &token
+        {
+            vocab::token[id]
+        };
+
+        const string_view text
+        {
+            token, strnlen(token, 16)
+        };
+
+        string_view dest
+        {
+            data(out + off), copy(out + off, text)
+        };
+
+        dest = replace(out + off, dest, "Ġ"_sv, " "_sv);
+        dest = replace(out + off, dest, "Ċ"_sv, "\n"_sv);
+        off += size(dest);
+    }
+
+    assert(off <= size(out));
     return string_view
     {
-        data(out), data(buf)
+        data(out), off
     };
 }
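The two replace() calls are the inverse convenience for the common cases: a stored token like "Ġhi" prints as " hi" and "Ċ" prints as a newline. A standalone sketch of the same post-processing on a std::string (hypothetical helper, not from this commit):

    #include <string>

    void restore_whitespace(std::string &s)
    {
        static const std::string g("\xc4\xa0"); // UTF-8 for U+0120 'Ġ'
        static const std::string c("\xc4\x8a"); // UTF-8 for U+010A 'Ċ'
        for(size_t p; (p = s.find(g)) != std::string::npos; )
            s.replace(p, g.size(), " ");
        for(size_t p; (p = s.find(c)) != std::string::npos; )
            s.replace(p, c.size(), "\n");
    }

Note this only reverses the two most frequent remappings; a full decoder would invert the entire charset[] table byte for byte.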
@@ -143,7 +200,6 @@ noexcept
 ircd::vector_view<ircd::u16>
 ircd::gpt::vocab::tokenize(const vector_view<u16> &out,
                            const string_view &in)
-noexcept
 {
     using input_t = u8x16;
     using block_t = u16x16;
@@ -165,6 +221,7 @@ noexcept
     };
 
     assert(consumed[0] <= out.size());
+    assert(consumed[0] <= consumed[1]);
     return vector_view<u16>
     (
         out.data(), consumed[0]
@@ -183,27 +240,39 @@ noexcept
         pre_tokenize(pre_token, in, in_mask)
     };
 
-    uint tokens(0);
-    for(uint i(0); i < pre_tokens; ++i)
+    u64x2 ret
     {
+        0, consumed
+    };
+
+    for(uint i(0); i < pre_tokens && ret[0] < 16; ++i)
+    {
+        // one token in hand is worth two in the bpe
+        if(likely((token[ret[0]] = find_token(pre_token[i])) != u16(-1)))
+        {
+            ++ret[0];
+            continue;
+        }
+
         u8x16 str[16];
         const uint strs
         {
             bpe_tokenize(str, pre_token[i])
         };
 
-        const uint addl_tokens
-        {
-            find_tokens(token, tokens, str, strs)
-        };
+        for(uint j(0); j < strs && ret[0] < 16; ++j)
+        {
+            if(likely((token[ret[0]] = find_token(str[j])) != u16(-1)))
+            {
+                ++ret[0];
+                continue;
+            }
 
-        tokens += addl_tokens;
+            ret += unk_tokenize(token, str[j], ret[0]);
+        }
     }
 
-    return u64x2
-    {
-        tokens, consumed
-    };
+    return ret;
 }
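The control flow above is a three-tier strategy: first try the whole pre-token as a single vocabulary hit (the "one token in hand" fast path), then byte-pair encode and look up the resulting pieces, and only for pieces still unknown fall back to unk_tokenize(). A map-based sketch of the same tiering (standalone; bpe_stub() is a placeholder, not the commit's bpe_tokenize()):

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using vocab_t = std::unordered_map<std::string, uint16_t>;

    // Placeholder for the real BPE split; here it just falls apart into bytes.
    std::vector<std::string> bpe_stub(const std::string &s)
    {
        std::vector<std::string> out;
        for(const char c : s)
            out.emplace_back(1, c);
        return out;
    }

    void tokenize_one(std::vector<uint16_t> &out, const vocab_t &vocab,
                      const std::string &pre_token)
    {
        // Tier 1: the whole pre-token is itself a vocabulary entry.
        if(const auto it(vocab.find(pre_token)); it != vocab.end())
        {
            out.push_back(it->second);
            return;
        }

        // Tier 2: byte-pair encode, then look up each piece; tier 3 (greedy
        // fallback for misses) is sketched after unk_tokenize() further down.
        for(const auto &piece : bpe_stub(pre_token))
            if(const auto it(vocab.find(piece)); it != vocab.end())
                out.push_back(it->second);
    }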
 //
 
@@ -221,107 +290,111 @@ ircd::u64x2
 ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
                                const u8x16 in,
                                const u8x16 in_mask)
-noexcept
 {
     const u8x16 is_ascii_ctrl
-    {
+    (
         in < 0x20
-    };
+    );
 
     const u8x16 is_ascii_space
-    {
+    (
         in == ' '
-    };
+    );
 
     const u8x16 is_ascii_number
-    {
+    (
         in >= '0' && in <= '9'
-    };
+    );
 
     const u8x16 is_ascii_letter
-    {
+    (
         (in >= 'a' && in <= 'z') || (in >= 'A' && in <= 'Z')
-    };
+    );
 
     const u8x16 ascii_identified
-    {
-        is_ascii_space | is_ascii_number | is_ascii_letter
-    };
+    (
+        is_ascii_ctrl | is_ascii_space | is_ascii_number | is_ascii_letter
+    );
 
     const u8x16 maybe_notascii
-    {
+    (
         ~ascii_identified & in_mask
-    };
+    );
 
     const u32x16 ch
-    {
+    (
         utf8::decode(in)
-    };
+    );
 
     const u32x16 uc_cat
-    {
+    (
         icu::category(ch & (lane_cast<u32x16>(maybe_notascii) != 0))
-    };
+    );
 
     const u32x16 is_L
-    {0
+    (0
     | ((uc_cat & 0x0000003eU) != 0)
     | (lane_cast<u32x16>(is_ascii_letter) != 0)
-    };
+    );
 
     const u32x16 is_N
-    {0
+    (0
    | ((uc_cat & 0x00000e00U) != 0)
    | (lane_cast<u32x16>(is_ascii_number) != 0)
-    };
+    );
 
     const u32x16 is_Z
-    {0
+    (0
    | ((uc_cat & 0x00007000U) != 0)
    | (lane_cast<u32x16>(is_ascii_space) != 0)
-    };
+    );
 
+    const u32x16 is_C0
+    (0
+    | (lane_cast<u32x16>(is_ascii_ctrl) != 0)
+    );
+
     const u32x16 is_trail
-    {0
+    (0
    | (is_L & shl<32>(is_L))
    | (is_N & shl<32>(is_N))
    | (is_Z & shl<32>(is_Z))
-    };
+    );
 
     const u32x16 fat_mask
-    {
+    (
         lane_cast<u32x16>(in_mask) != 0
-    };
+    );
 
     const u32x16 is_head
-    {
-        ~is_trail & fat_mask
-    };
+    (
+        (~is_trail | is_C0) & fat_mask
+    );
 
     // mask if token is preceded by a space
     const u32x16 leading_space
-    {
+    (
         is_head & shl<32>(is_Z)
-    };
+    );
 
     // zero or one preceding space becomes prefixed to the next token
     const u32x16 tok_head
-    {0
+    (0
    | (is_head & ~leading_space)
    | shr<32>(leading_space)
-    };
+    );
 
     const u32x16 tok_trail
-    {
+    (
         ~tok_head
-    };
+    );
 
     const u32x16 tok_mask
-    {
+    (
         tok_trail
-    };
+    );
 
-    const u64x2 ret
+    const auto ret
     {
         pre_tokenize_split(token, ch, fat_mask, tok_mask)
     };
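What the mask algebra computes: every lane is classified (letter L, number N, space Z, via the ICU general-category bits or the ASCII fast path), runs of the same class are fused into one token, and a single preceding space is folded onto the token that follows it; with this commit, C0 controls always start their own token. A scalar sketch of the same splitting rule (standalone, ASCII-only for brevity):

    #include <cctype>
    #include <string>
    #include <vector>

    int char_class(unsigned char c)
    {
        if(std::isalpha(c)) return 1; // L
        if(std::isdigit(c)) return 2; // N
        if(c == ' ')        return 3; // Z
        return 0;                     // other, incl. C0: always a head
    }

    std::vector<std::string> pre_tokenize_scalar(const std::string &in)
    {
        std::vector<std::string> out;
        for(size_t i(0); i < in.size(); )
        {
            const size_t start(i);
            if(in[i] == ' ' && i + 1 < in.size() && char_class(in[i + 1]) != 3)
                ++i; // one leading space joins the token it precedes

            const int cls(char_class(in[i]));
            ++i;
            if(cls != 0)
                while(i < in.size() && char_class(in[i]) == cls)
                    ++i;
            if(cls == 3 && i < in.size())
                --i; // last space of a run is claimed by the next token

            out.emplace_back(in, start, i - start);
        }
        return out;
    }

pre_tokenize_scalar("Hello world 123") splits into "Hello", " world", " 123", mirroring GPT-2's convention that a space travels with the word it precedes.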
@@ -337,69 +410,139 @@ ircd::gpt::vocab::pre_tokenize_split(u8x16 (&token)[16],
                                      u32x16 ch,
                                      u32x16 ch_mask,
                                      u32x16 tok_mask)
-noexcept
 {
-    const u32x16 lane0_mask
-    {
-        -1U
-    };
-
-    u64x2 ret
-    {
-        0, 0
-    };
-
-    for(uint i(0); ret[0] == i && ret[1] < 16; ++i)
+    // Replace single-byte codepoints from the LUT.
+    u32x16 rch;
+    for(uint i(0); i < 16; ++i)
+        rch[i] = ch[i] > 0xFF?
+            ch[i]: charset[ch[i]];
+
+    u64x2 ret {0, 0};
+    for(uint i(0); ret[0] >= i && ret[1] < 16; ++i)
     {
+        static const u32x16 lane0_mask
+        {
+            -1U
+        };
+
         // Create a mask from all non-leading characters of input tokens with
         // a mask of just the leading character of the first token. To be sure
         // extra characters are not included we rinse it with the ch_mask.
         const u32x16 cover_mask
-        {
+        (
             (lane0_mask | tok_mask) & ch_mask
-        };
+        );
 
-        // Get the length of the first token from the cover.
-        const u64 len
-        {
-            std::min(simd::lzcnt(~cover_mask) / 32, 16U)
-        };
+        // Get the number of codepoints of the first token from the cover.
+        const auto cp_num
+        {
+            std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL)
+        };
+
+        // Input codepoint lengths
+        const u32x16 cp_len
+        (
+            utf8::length(ch & cover_mask)
+        );
+
+        // Output codepoint lengths
+        const u32x16 rcp_len
+        (
+            utf8::length(rch & cover_mask)
+        );
+
+        // Generate utf-8 codepoints
+        const u8x64 rch8
+        (
+            utf8::encode(rch & cover_mask)
+        );
+
+        u32x16 idx;
+        uint off(0); // result bytes of utf-8
+        for(uint j(0); j < cp_num; off += rcp_len[j++])
+            idx[j] = off;
+
+        uint len(0); // input bytes of utf-8
+        for(uint j(0); j < cp_num; ++j)
+            len += cp_len[j];
 
         // When the first token is too large, we truncate that token here and
         // return, effectively splitting the token into multiple. If the token
         // after the first is too large (input potentially spans into the next
         // block), we kick it to the next iteration entirely.
-        const bool skip
+        assert(ret[1] <= 16);
+        const auto skip
         {
-            len >= 16 - ret[1] && ret[0]
+            boolmask<u64>(ret[1] + off >= 16 && i > 0)
         };
 
-        // Generate utf-8 codepoints
-        const u32x16 ch8
-        {
-            utf8::encode(ch & cover_mask)
-        };
+        // We have to return the proper number of bytes for what was truncated
+        // from the input, but the truncation is determined after a transform
+        // which may have a different size; this has to be offset back now.
+        if(!skip && ret[1] + off > 16)
+        {
+            assert(off >= len);
+            len -= (off - len);
+        }
 
         // Pack the utf-8 codepoints into the result token
         token[i] = {0};
-        for(uint j(0); j < 16; ++j)
-            token[i][j] = ch8[j];
+        for(uint j(0); j < cp_num; ++j)
+            for(uint k(0); k < rcp_len[j] && idx[j] + k < 16; ++k)
+                token[i][idx[j] + k] = rch8[j * 4 + k];
 
         // Shift the token off the input to consume the next.
-        for(uint j(0); j < len; ++j)
+        for(uint j(0); j < cp_num; ++j)
         {
             ch = shr<32>(ch);
+            rch = shr<32>(rch);
             ch_mask = shr<32>(ch_mask);
             tok_mask = shr<32>(tok_mask);
         }
 
-        ret[0] += bool(len) && !skip;
-        ret[1] += len & boolmask<u64>(!skip);
+        ret[0] += !skip && len;
+        ret[1] += ~skip & len;
     }
 
     return ret;
 }
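The cover_mask/lzcnt idiom deserves a gloss: lanes covered by the first token form a contiguous run of set lanes starting at lane zero (lane0_mask forces lane zero on, since the first lane is by definition a head), so a leading-ones count over the 32-bit lanes yields the token's codepoint count directly. A scalar equivalent, assuming lane order matches bit order (illustrative only):

    // cover[n] true while lane n still belongs to the first token.
    unsigned first_token_codepoints(const bool (&cover)[16])
    {
        unsigned n(0);
        while(n < 16 && cover[n])
            ++n;
        return n; // == simd::lzcnt(~cover_mask) / 32 in the vector code
    }

The separate cp_len/rcp_len accounting exists because the LUT transform can change a codepoint's UTF-8 width: the consumed-byte count returned to the caller must be measured against the input (ch), while the packed output is measured against the remapped text (rch).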
+[[gnu::noinline]]
+ircd::u64x2
+ircd::gpt::vocab::unk_tokenize(u16x16 &token,
+                               const u8x16 str,
+                               const u64 num)
+{
+    u64 tokens(0), consumed(0);
+    const auto len(simd::strlen(str));
+    while(consumed < len && num + tokens < 16)
+        for(uint i(0); i < len; ++i)
+        {
+            u8x16 s(str);
+            for(uint j(0); j < consumed; ++j)
+                s = shr<8>(s);
+
+            for(uint j(len - i); j < 16; ++j)
+                s[j] = 0;
+
+            if((token[num + tokens] = find_token(s)) != u16(-1))
+            {
+                consumed += len - i;
+                ++tokens;
+                break;
+            }
+        }
+
+    assert(len >= consumed);
+    assert(num + tokens <= 16);
+    const auto overflow{len - consumed};
+    assert(overflow == 0 || num + tokens == 16);
+    return u64x2
+    {
+        tokens, 0
+    };
+}
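unk_tokenize() is a greedy longest-match fallback: starting from the front of the unknown piece, it finds the longest prefix that is a vocabulary entry, emits it, and repeats on the remainder. A standalone scalar sketch of the same loop (map-based, names illustrative):

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    std::vector<uint16_t>
    unk_tokenize_scalar(const std::unordered_map<std::string, uint16_t> &vocab,
                        const std::string &str)
    {
        std::vector<uint16_t> out;
        size_t consumed(0);
        while(consumed < str.size())
        {
            size_t take(str.size() - consumed);
            for(; take > 0; --take)
            {
                const auto it(vocab.find(str.substr(consumed, take)));
                if(it != vocab.end())
                {
                    out.push_back(it->second);
                    break;
                }
            }
            if(!take)
                break; // guard; byte-level vocabs contain every single byte
            consumed += take;
        }
        return out;
    }

Since a byte-level BPE vocabulary has an entry for every remapped single byte, the inner scan always terminates with at least a one-character match.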
 //
 // byte-pair encoding
 //
 
@@ -408,7 +551,6 @@ noexcept
 uint
 ircd::gpt::vocab::bpe_tokenize(u8x16 (&str)[16],
                                const u8x16 pre_token)
-noexcept
 {
     if(simd::strlen(pre_token) < 2)
     {
@@ -430,9 +572,6 @@ noexcept
         bpe_score(score, pair, pairs)
     };
 
-    if(best_score >= u16(-1))
-        break;
-
     const auto merges
     {
         bpe_merge(pair, score, pairs, best_score)
@@ -454,48 +593,42 @@ noexcept
 uint
 ircd::gpt::vocab::bpe_prepare(u8x16 (&out)[16][2],
                               const u8x16 in)
-noexcept
 {
-    uint di, si;
-    for(di = 0, si = 0; si < 16 && di < 16; si += 2, di += 2)
-    {
-        out[di][0] = {0};
-        out[di][1] = {0};
-
-        if(!in[si] || !in[si + 1])
-            break;
-
-        //TODO: XXX
-        if(!si && in[si] == ' ')
-        {
-            out[di][0][0] = 0xc4;
-            out[di][0][1] = 0xa0;
-            out[di][1][0] = in[si + 1];
-            continue;
-        }
-
-        out[di][0][0] = in[si];
-        out[di][1][0] = in[si + 1];
-    }
-
-    for(di = 1, si = 1; si < 16 && di < 16; si += 2, di += 2)
-    {
-        out[di][0] = {0};
-        out[di][1] = {0};
-
-        if(!in[si] || !in[si + 1])
-            break;
-
-        out[di][0][0] = in[si];
-        out[di][1][0] = in[si + 1];
-    }
-
-    return di;
+    const auto len
+    {
+        simd::strlen(in)
+    };
+
+    const u32x16 cplen
+    (
+        utf8::length(utf8::decode(in))
+    );
+
+    u32x16 idx;
+    for(uint i(0), off(0); i < 16; off += cplen[i++])
+        idx[i] = off;
+
+    uint ret(0);
+    for(uint phase(0); phase < 2; ++phase)
+    for(uint i(phase); i < 16; i += 2, ++ret)
+    {
+        if(idx[i] >= 16 || !in[idx[i]])
+            break;
+
+        out[i][0] = {0};
+        out[i][1] = {0};
+        for(uint k(0); k < 2; ++k)
+            for(uint j(0); j < cplen[i + k] && idx[i + k] + j < 16; ++j)
+                out[i][k][j] = in[idx[i + k] + j];
+    }
+
+    return ret;
 }
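bpe_prepare() now lays out the pair table by codepoint rather than by byte: the UTF-8 input is indexed per codepoint, and adjacent codepoints are paired in two phases, the even-offset pairs first and the overlapping odd-offset pairs second. A standalone sketch of that layout over pre-split codepoints:

    #include <string>
    #include <utility>
    #include <vector>

    // One entry of `cp` per codepoint; returns (cp[0],cp[1]), (cp[2],cp[3]),
    // ... followed by (cp[1],cp[2]), (cp[3],cp[4]), ...
    std::vector<std::pair<std::string, std::string>>
    adjacent_pairs(const std::vector<std::string> &cp)
    {
        std::vector<std::pair<std::string, std::string>> out;
        for(unsigned phase(0); phase < 2; ++phase)
            for(size_t i(phase); i + 1 < cp.size(); i += 2)
                out.emplace_back(cp[i], cp[i + 1]);
        return out;
    }

This also appears to retire the old special case that spelled a leading space as the two bytes 0xc4 0xa0 (UTF-8 "Ġ") by hand; with the charset[] LUT applied earlier in the pipeline, "Ġ" arrives here as an ordinary codepoint.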
 uint
 ircd::gpt::vocab::bpe_postpare(u8x16 (&out)[16],
                                const u8x16 (&in)[16][2],
                                const uint num)
-noexcept
 {
     uint ret(0);
     for(uint j(0); j < num; ++j)
 
@@ -514,7 +647,6 @@ ircd::gpt::vocab::bpe_merge(u8x16 (&pair)[16][2],
                             u16 (&score)[16],
                             const uint num,
                             const u16 best_score)
-noexcept
 {
     uint ret(0);
 
@@ -552,7 +684,6 @@ ircd::u16
 ircd::gpt::vocab::bpe_score(u16 (&score)[16],
                             const u8x16 (&pair)[16][2],
                             const uint num)
-noexcept
 {
     uint best(-1U), is_min;
     for(uint i(0); i < num; i++)
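Taken together, bpe_score() and bpe_merge() implement one round of the classic BPE loop: rank every adjacent pair against the merges list, merge the best-ranked occurrence, and iterate until nothing merges. For reference, the textbook scalar form (standalone sketch, not this code):

    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    using rank_map = std::map<std::pair<std::string, std::string>, unsigned>;

    void bpe_scalar(std::vector<std::string> &pieces, const rank_map &merges)
    {
        while(pieces.size() > 1)
        {
            size_t best(0);
            unsigned best_rank(~0U);
            for(size_t i(0); i + 1 < pieces.size(); ++i)
            {
                const auto it(merges.find({pieces[i], pieces[i + 1]}));
                if(it != merges.end() && it->second < best_rank)
                    best_rank = it->second, best = i;
            }
            if(best_rank == ~0U)
                break; // no adjacent pair appears in the merges list

            pieces[best] += pieces[best + 1];
            pieces.erase(pieces.begin() + best + 1);
        }
    }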
@@ -575,36 +706,8 @@ noexcept
 // queries
 //
 
-uint
-ircd::gpt::vocab::find_tokens(u16x16 &token,
-                              const uint tokens,
-                              const u8x16 (&str)[16],
-                              const uint strs)
-noexcept
-{
-    uint ret(0);
-    for(; tokens + ret < 16 && ret < strs; ++ret)
-    {
-        const auto val
-        {
-            find_token(str[ret])
-        };
-
-        const bool found
-        {
-            val != u16(-1)
-        };
-
-        assert(found);
-        token[tokens + ret] = val;
-    }
-
-    return ret;
-}
-
 ircd::u16
 ircd::gpt::vocab::find_token(const u8x16 string)
-noexcept
 {
     const auto *const __restrict__ token
     {
@@ -621,7 +724,6 @@ noexcept
 ircd::u16
 ircd::gpt::vocab::find_merge(const u8x16 a,
                              const u8x16 b)
-noexcept
 {
     const auto &__restrict__ merge
     {