mirror of https://github.com/matrix-construct/construct (synced 2024-12-28 00:14:07 +01:00)
ircd::gpt::vocab: Minor reorg pre-tokenize related.
This commit is contained in:
parent b6e2876af4
commit b2f788e255
1 changed file with 117 additions and 109 deletions
@@ -17,7 +17,7 @@ namespace ircd::gpt::vocab
 	static uint bpe_postpare(u8x16 (&)[16], const u8x16 (&)[16][2], const uint);
 	static uint bpe_prepare(u8x16 (&)[16][2], const u8x16);
 	static uint bpe_tokenize(u8x16 (&)[16], const u8x16);
-	static u64x2 pre_tokenize_split(u8x16 (&)[16], u32x16, u32x16, u32x16);
+	static std::array<u32x16, 3> pre_tokenize_split(const u8x16, const u8x16);
 	static u64x2 pre_tokenize(u8x16 (&)[16], const u8x16, const u8x16);
 	static u64x2 unk_tokenize(u16x16 &, const u8x16, u64);
 	static u64x2 tokenize_block(u16x16 &, const u8x16, const u8x16) noexcept;
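The declaration change above summarizes the reorg: the token-packing overload of pre_tokenize_split is dropped, and the splitter instead returns its three mask vectors as one aggregate which pre_tokenize consumes with a structured binding. Below is a minimal sketch of that call convention; Vec and split() are hypothetical stand-ins for illustration, not ircd's u32x16 or pre_tokenize_split().

// Minimal sketch of the new call convention; Vec and split() are
// hypothetical stand-ins, not ircd's u32x16 or pre_tokenize_split().
#include <array>
#include <cstdint>

using Vec = std::array<uint32_t, 16>;  // stand-in for ircd::u32x16

// Return the decoded codepoints plus the character and token-boundary
// masks as one aggregate, like the new pre_tokenize_split() overload.
static std::array<Vec, 3>
split(const Vec &in)
{
	Vec ch {}, ch_mask {}, tok_mask {};
	for(unsigned i(0); i < 16; ++i)
	{
		ch[i] = in[i];                 // real decoding would happen here
		ch_mask[i] = in[i]? ~0U: 0;    // lane carries valid input
		tok_mask[i] = i == 0? ~0U: 0;  // lane starts a token (stub)
	}
	return {ch, ch_mask, tok_mask};
}

int main()
{
	const Vec input {{'a', 'b', 'c'}};

	// One structured binding replaces three separate out-parameters.
	const auto [ch, ch_mask, tok_mask]
	{
		split(input)
	};

	return int((ch[0] ^ ch_mask[0] ^ tok_mask[0]) & 1);
}

Returning one aggregate leaves the splitter with no output-array state, so all mutation lands in the packing loop that pre_tokenize now owns.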
@@ -317,10 +317,119 @@ noexcept
 /// The return value in [0] indicates the number of tokens populated in the
 /// array; the value in [1] indicates the bytes consumed from the input.
 ///
+/// Split single vector of UTF-32 codepoints into vectors of UTF-8 strings for
+/// each token determined by the input masks. Returns the number of tokens in
+/// [0] and the number of codepoints consumed in [1].
 ircd::u64x2
 ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
                                const u8x16 in,
                                const u8x16 in_mask)
 {
+	auto [ch, ch_mask, tok_mask]
+	{
+		pre_tokenize_split(in, in_mask)
+	};
+
+	// Replace single-byte codepoints from the LUT.
+	u32x16 rch;
+	for(uint i(0); i < 16; ++i)
+		rch[i] = ch[i] > 0xFF?
+			ch[i]: charset[ch[i]];
+
+	u64x2 ret {0, 0};
+	for(uint i(0); ret[0] >= i && ret[1] < 16; ++i)
+	{
+		static const u32x16 lane0_mask
+		{
+			-1U
+		};
+
+		// Create a mask from all non-leading characters of input tokens with
+		// a mask of just the leading character of the first token. To be sure
+		// extra characters are not included we rinse it with the ch_mask.
+		const u32x16 cover_mask
+		(
+			(lane0_mask | tok_mask) & ch_mask
+		);
+
+		// Get the number of codepoints of the first token from the cover.
+		const auto cp_num
+		{
+			std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL)
+		};
+
+		// Input codepoint lengths
+		const u32x16 cp_len
+		(
+			utf8::length(ch & cover_mask)
+		);
+
+		// Output codepoint lengths
+		const u32x16 rcp_len
+		(
+			utf8::length(rch & cover_mask)
+		);
+
+		// Generate utf-8 codepoints
+		const u8x64 rch8
+		(
+			utf8::encode(rch & cover_mask)
+		);
+
+		u32x16 idx;
+		uint off(0); // result bytes of utf-8
+		for(uint j(0); j < cp_num; off += rcp_len[j++])
+			idx[j] = off;
+
+		uint len(0); // input bytes of utf-8
+		for(uint j(0); j < cp_num; ++j)
+			len += cp_len[j];
+
+		// When the first token is too large, we truncate that token here and
+		// return, effectively splitting the token into multiple. If the token
+		// after the first is too large (input potentially spans into the next
+		// block), we kick it to the next iteration entirely.
+		assert(ret[1] <= 16);
+		const auto skip
+		{
+			boolmask<u64>(ret[1] + off >= 16 && i > 0)
+		};
+
+		// We have to return the proper number of bytes for what was truncated
+		// from the input, but the truncation is determined after a transform
+		// which may have a different size; this has to be offset back now.
+		if(!skip && ret[1] + off > 16)
+		{
+			assert(off >= len);
+			len -= (off - len);
+		}
+
+		// Pack the utf-8 codepoints into the result token
+		token[i] = {0};
+		for(uint j(0); j < cp_num; ++j)
+			for(uint k(0); k < rcp_len[j] && idx[j] + k < 16; ++k)
+				token[i][idx[j] + k] = rch8[j * 4 + k];
+
+		// Shift the token off the input to consume the next.
+		for(uint j(0); j < cp_num; ++j)
+		{
+			ch = shr<32>(ch);
+			rch = shr<32>(rch);
+			ch_mask = shr<32>(ch_mask);
+			tok_mask = shr<32>(tok_mask);
+		}
+
+		ret[0] += !skip && len;
+		ret[1] += ~skip & len;
+	}
+
+	return ret;
+}
+
+std::array<ircd::u32x16, 3>
+ircd::gpt::vocab::pre_tokenize_split(const u8x16 in,
+                                     const u8x16 in_mask)
+{
 	const u8x16 is_ascii_ctrl
 	(
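The densest step in the relocated loop body is the cover_mask counting: the leading all-ones 32-bit lanes of cover_mask span exactly the first token's codepoints, so simd::lzcnt(~cover_mask) / 32, capped at 16, yields cp_num. Here is a scalar sketch of the same trick, assuming C++20 std::countl_one; packing one predicate bit per lane into a uint16_t is my own illustration, not ircd's simd implementation.

// Scalar sketch of the cover-mask count from the loop above; each
// array element models one u32 lane of cover_mask, where all-ones
// lanes cover the first token's codepoints.
#include <bit>
#include <cstdint>
#include <cstdio>

// Count leading all-ones lanes, i.e. cp_num in the diff:
// std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL).
static unsigned
leading_cover(const uint32_t (&cover_mask)[16])
{
	// Pack one predicate bit per lane, lane 0 in the MSB, so a
	// leading-ones count over the bits mirrors lzcnt over ~vector.
	uint16_t bits(0);
	for(unsigned i(0); i < 16; ++i)
		bits |= (cover_mask[i] == ~0U) << (15 - i);

	// countl_one(x) == lzcnt(~x); the 16-bit width caps the result
	// at 16 lanes, matching the std::min() in the diff.
	return std::countl_one(bits);
}

int main()
{
	const uint32_t cover[16] {~0U, ~0U, ~0U}; // 3-codepoint first token
	std::printf("cp_num=%u\n", leading_cover(cover)); // prints cp_num=3
	return 0;
}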
@@ -453,118 +562,17 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
 		tok_trail
 	);
 
-	const auto ret
+	return
 	{
-		pre_tokenize_split(token, ch, ch_mask, tok_mask)
+		ch,
+		ch_mask,
+		tok_mask
 	};
-
-	return ret;
 }
 
-/// Split single vector of UTF-32 codepoints into vectors of UTF-8 strings for
-/// each token determined by the input masks. Returns the number of tokens in
-/// [0] and the number of codepoints consumed in [1].
-ircd::u64x2
-ircd::gpt::vocab::pre_tokenize_split(u8x16 (&token)[16],
-                                     u32x16 ch,
-                                     u32x16 ch_mask,
-                                     u32x16 tok_mask)
-{
-	// Replace single-byte codepoints from the LUT.
-	u32x16 rch;
-	for(uint i(0); i < 16; ++i)
-		rch[i] = ch[i] > 0xFF?
-			ch[i]: charset[ch[i]];
-
-	u64x2 ret {0, 0};
-	for(uint i(0); ret[0] >= i && ret[1] < 16; ++i)
-	{
-		static const u32x16 lane0_mask
-		{
-			-1U
-		};
-
-		// Create a mask from all non-leading characters of input tokens with
-		// a mask of just the leading character of the first token. To be sure
-		// extra characters are not included we rinse it with the ch_mask.
-		const u32x16 cover_mask
-		(
-			(lane0_mask | tok_mask) & ch_mask
-		);
-
-		// Get the number of codepoints of the first token from the cover.
-		const auto cp_num
-		{
-			std::min(simd::lzcnt(~cover_mask) / 32UL, 16UL)
-		};
-
-		// Input codepoint lengths
-		const u32x16 cp_len
-		(
-			utf8::length(ch & cover_mask)
-		);
-
-		// Output codepoint lengths
-		const u32x16 rcp_len
-		(
-			utf8::length(rch & cover_mask)
-		);
-
-		// Generate utf-8 codepoints
-		const u8x64 rch8
-		(
-			utf8::encode(rch & cover_mask)
-		);
-
-		u32x16 idx;
-		uint off(0); // result bytes of utf-8
-		for(uint j(0); j < cp_num; off += rcp_len[j++])
-			idx[j] = off;
-
-		uint len(0); // input bytes of utf-8
-		for(uint j(0); j < cp_num; ++j)
-			len += cp_len[j];
-
-		// When the first token is too large, we truncate that token here and
-		// return, effectively splitting the token into multiple. If the token
-		// after the first is too large (input potentially spans into the next
-		// block), we kick it to the next iteration entirely.
-		assert(ret[1] <= 16);
-		const auto skip
-		{
-			boolmask<u64>(ret[1] + off >= 16 && i > 0)
-		};
-
-		// We have to return the proper number of bytes for what was truncated
-		// from the input, but the truncation is determined after a transform
-		// which may have a different size; this has to be offset back now.
-		if(!skip && ret[1] + off > 16)
-		{
-			assert(off >= len);
-			len -= (off - len);
-		}
-
-		// Pack the utf-8 codepoints into the result token
-		token[i] = {0};
-		for(uint j(0); j < cp_num; ++j)
-			for(uint k(0); k < rcp_len[j] && idx[j] + k < 16; ++k)
-				token[i][idx[j] + k] = rch8[j * 4 + k];
-
-		// Shift the token off the input to consume the next.
-		for(uint j(0); j < cp_num; ++j)
-		{
-			ch = shr<32>(ch);
-			rch = shr<32>(rch);
-			ch_mask = shr<32>(ch_mask);
-			tok_mask = shr<32>(tok_mask);
-		}
-
-		ret[0] += !skip && len;
-		ret[1] += ~skip & len;
-	}
-
-	return ret;
-}
-
 //
 // post-tokenizer
 //
 
 [[gnu::noinline]]
 ircd::u64x2
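Both the old and new loop bodies end with the same branchless accumulation: boolmask<u64> widens the skip predicate to all-ones or all-zeros, so a token kicked to the next block contributes to neither counter, while an accepted token adds one to the token count and len to the byte count. A self-contained sketch, assuming only that ircd's boolmask<u64> behaves like the stand-in below.

// Sketch of the branchless accumulation at the bottom of the loop;
// boolmask() here is a stand-in for ircd's boolmask<u64>, which widens
// a predicate to all-ones (true) or all-zeros (false).
#include <cassert>
#include <cstdint>

static uint64_t
boolmask(const bool b)
{
	return b? ~uint64_t(0): 0;
}

int main()
{
	uint64_t ret[2] {0, 0};
	const unsigned len(5);

	// Token kicked to the next block: neither counter advances.
	uint64_t skip(boolmask(true));
	ret[0] += !skip && len;   // adds 0: token not counted
	ret[1] += ~skip & len;    // adds 0: no bytes consumed

	// Token accepted: count it and consume its bytes.
	skip = boolmask(false);
	ret[0] += !skip && len;   // adds 1: one token produced
	ret[1] += ~skip & len;    // adds 5: len bytes consumed

	assert(ret[0] == 1 && ret[1] == 5);
	return 0;
}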