0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-12-28 00:14:07 +01:00

ircd::gpt::vocab: Minor reorg pre-tokenize related.

This commit is contained in:
Jason Volk 2021-04-20 12:23:02 -07:00
parent b6e2876af4
commit b2f788e255

View file

@ -17,7 +17,7 @@ namespace ircd::gpt::vocab
static uint bpe_postpare(u8x16 (&)[16], const u8x16 (&)[16][2], const uint);
static uint bpe_prepare(u8x16 (&)[16][2], const u8x16);
static uint bpe_tokenize(u8x16 (&)[16], const u8x16);
static u64x2 pre_tokenize_split(u8x16 (&)[16], u32x16, u32x16, u32x16);
static std::array<u32x16, 3> pre_tokenize_split(const u8x16, const u8x16);
static u64x2 pre_tokenize(u8x16 (&)[16], const u8x16, const u8x16);
static u64x2 unk_tokenize(u16x16 &, const u8x16, u64);
static u64x2 tokenize_block(u16x16 &, const u8x16, const u8x16) noexcept;
@ -317,10 +317,119 @@ noexcept
/// The return value in [0] indicates the number of tokens populated in the
/// array; the value in [1] indicates the bytes consumed from the input.
///
/// Split single vector of UTF-32 codepoints into vectors of UTF-8 strings for
/// each token determined by the input masks. Returns the number of tokens in
/// [0] and the number of codepoints consumed in [1].
ircd::u64x2
ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
const u8x16 in,
const u8x16 in_mask)
{
auto [ch, ch_mask, tok_mask]
{
pre_tokenize_split(in, in_mask)
};
// Replace single-byte codepoints from the LUT.
u32x16 rch;
for(uint i(0); i < 16; ++i)
rch[i] = ch[i] > 0xFF?
ch[i]: charset[ch[i]];
u64x2 ret {0, 0};
for(uint i(0); ret[0] >= i && ret[1] < 16; ++i)
{
static const u32x16 lane0_mask
{
-1U
};
// Create a mask from all non-leading characters of input tokens with
// a mask of just the leading character of the first token. To be sure
// extra characters are not included we rinse it with the ch_mask.
const u32x16 cover_mask
(
(lane0_mask | tok_mask) & ch_mask
);
// Get the number of codepoints of the first token from the cover.
const auto cp_num
{
std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL)
};
// Input codepoint lengths
const u32x16 cp_len
(
utf8::length(ch & cover_mask)
);
// Output codepoint lengths
const u32x16 rcp_len
(
utf8::length(rch & cover_mask)
);
// Generate utf-8 codepoints
const u8x64 rch8
(
utf8::encode(rch & cover_mask)
);
u32x16 idx;
uint off(0); // result bytes of utf-8
for(uint j(0); j < cp_num; off += rcp_len[j++])
idx[j] = off;
uint len(0); // input bytes of utf-8
for(uint j(0); j < cp_num; ++j)
len += cp_len[j];
// When the first token is too large, we truncate that token here and
// return, effectively splitting the token into multiple. If the token
// after the first is too large (input potentially spans into the next
// block), we kick it to the next iteration entirely.
assert(ret[1] <= 16);
const auto skip
{
boolmask<u64>(ret[1] + off >= 16 && i > 0)
};
// We have to return the proper number of bytes for what was truncated
// from the input, but the truncation is determined after a transform
// which may have a different size; this has to be offset back now.
if(!skip && ret[1] + off > 16)
{
assert(off >= len);
len -= (off - len);
}
// Pack the utf-8 codepoints into the result token
token[i] = {0};
for(uint j(0); j < cp_num; ++j)
for(uint k(0); k < rcp_len[j] && idx[j] + k < 16; ++k)
token[i][idx[j] + k] = rch8[j * 4 + k];
// Shift the token off the input to consume the next.
for(uint j(0); j < cp_num; ++j)
{
ch = shr<32>(ch);
rch = shr<32>(rch);
ch_mask = shr<32>(ch_mask);
tok_mask = shr<32>(tok_mask);
}
ret[0] += !skip && len;
ret[1] += ~skip & len;
}
return ret;
}
std::array<ircd::u32x16, 3>
ircd::gpt::vocab::pre_tokenize_split(const u8x16 in,
const u8x16 in_mask)
{
const u8x16 is_ascii_ctrl
(
@ -453,118 +562,17 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
tok_trail
);
const auto ret
return
{
pre_tokenize_split(token, ch, ch_mask, tok_mask)
ch,
ch_mask,
tok_mask
};
return ret;
}
/// Split single vector of UTF-32 codepoints into vectors of UTF-8 strings for
/// each token determined by the input masks. Returns the number of tokens in
/// [0] and the number of codepoints consumed in [1].
ircd::u64x2
ircd::gpt::vocab::pre_tokenize_split(u8x16 (&token)[16],
u32x16 ch,
u32x16 ch_mask,
u32x16 tok_mask)
{
// Replace single-byte codepoints from the LUT.
u32x16 rch;
for(uint i(0); i < 16; ++i)
rch[i] = ch[i] > 0xFF?
ch[i]: charset[ch[i]];
u64x2 ret {0, 0};
for(uint i(0); ret[0] >= i && ret[1] < 16; ++i)
{
static const u32x16 lane0_mask
{
-1U
};
// Create a mask from all non-leading characters of input tokens with
// a mask of just the leading character of the first token. To be sure
// extra characters are not included we rinse it with the ch_mask.
const u32x16 cover_mask
(
(lane0_mask | tok_mask) & ch_mask
);
// Get the number of codepoints of the first token from the cover.
const auto cp_num
{
std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL)
};
// Input codepoint lengths
const u32x16 cp_len
(
utf8::length(ch & cover_mask)
);
// Output codepoint lengths
const u32x16 rcp_len
(
utf8::length(rch & cover_mask)
);
// Generate utf-8 codepoints
const u8x64 rch8
(
utf8::encode(rch & cover_mask)
);
u32x16 idx;
uint off(0); // result bytes of utf-8
for(uint j(0); j < cp_num; off += rcp_len[j++])
idx[j] = off;
uint len(0); // input bytes of utf-8
for(uint j(0); j < cp_num; ++j)
len += cp_len[j];
// When the first token is too large, we truncate that token here and
// return, effectively splitting the token into multiple. If the token
// after the first is too large (input potentially spans into the next
// block), we kick it to the next iteration entirely.
assert(ret[1] <= 16);
const auto skip
{
boolmask<u64>(ret[1] + off >= 16 && i > 0)
};
// We have to return the proper number of bytes for what was truncated
// from the input, but the truncation is determined after a transform
// which may have a different size; this has to be offset back now.
if(!skip && ret[1] + off > 16)
{
assert(off >= len);
len -= (off - len);
}
// Pack the utf-8 codepoints into the result token
token[i] = {0};
for(uint j(0); j < cp_num; ++j)
for(uint k(0); k < rcp_len[j] && idx[j] + k < 16; ++k)
token[i][idx[j] + k] = rch8[j * 4 + k];
// Shift the token off the input to consume the next.
for(uint j(0); j < cp_num; ++j)
{
ch = shr<32>(ch);
rch = shr<32>(rch);
ch_mask = shr<32>(ch_mask);
tok_mask = shr<32>(tok_mask);
}
ret[0] += !skip && len;
ret[1] += ~skip & len;
}
return ret;
}
//
// post-tokenizer
//
[[gnu::noinline]]
ircd::u64x2