0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-06-11 06:28:55 +02:00

ircd::gpt::vocab: Fixes for additional mismatching cases.

This commit is contained in:
Jason Volk 2021-04-11 12:16:02 -07:00
parent 9c062d9c3f
commit eeadc15319

View file

@ -397,6 +397,7 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
lane_cast<u32x16>(in_mask) != 0
);
// mask candidate start of token
const u32x16 is_head
(
(~is_trail | is_C0) & fat_mask
@ -408,10 +409,16 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
is_head & shl<32>(is_Z)
);
// mask if next char is also the same char
const u32x16 is_rep
(
is_head & (shl<32>(ch) == ch)
);
// zero or one preceding space becomes prefixed to the next token
const u32x16 tok_head
(0
| (is_head & ~leading_space)
| (is_head & ~leading_space & ~is_rep)
| shr<32>(leading_space)
);
@ -544,38 +551,50 @@ ircd::gpt::vocab::unk_tokenize(u16x16 &token,
const u8x16 str,
const u64 num)
{
const auto len
{
simd::strlen(str)
};
u64 tokens(0), consumed(0);
const auto len(simd::strlen(str));
while(consumed < len && num + tokens < 16)
{
u16 slen(0), tok;
for(uint i(1); i < len; ++i)
uint slen(0);
for(uint i(0); i < len - consumed; ++i)
{
u8x16 s(str);
for(uint j(0); j < consumed; ++j)
s = shr<8>(s);
for(uint j(i); j < 16; ++j)
for(uint j(i + 1); j < 16; ++j)
s[j] = 0;
u16 tok;
if((tok = find_token(s)) == u16(-1))
continue;
slen = simd::strlen(s);
token[num + tokens] = tok;
slen = simd::strlen(s);
}
//assert(slen > 0);
consumed += slen;
tokens += bool(slen);
// Last possible branch; token is bytewise identity.
if(!slen)
token[num + tokens] = str[consumed];
consumed += std::max(slen, 1U);
tokens += 1U;
}
assert(len >= consumed);
assert(num + tokens <= 16);
const auto overflow{len - consumed};
assert(overflow == 0 || num + tokens == 16);
assert(consumed > 0 || tokens == 0);
assert(tokens > 0 || len == 0);
return u64x2
{
// return number of tokens created only; the caller already counted
// the length of str as consumed input.
tokens, 0
};
}