0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-11-04 21:08:57 +01:00

ircd::gpt::vocab: Fixes for additional mismatching cases.

This commit is contained in:
Jason Volk 2021-04-11 12:16:02 -07:00
parent 9c062d9c3f
commit eeadc15319

View file

@ -397,6 +397,7 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
lane_cast<u32x16>(in_mask) != 0 lane_cast<u32x16>(in_mask) != 0
); );
// mask candidate start of token
const u32x16 is_head const u32x16 is_head
( (
(~is_trail | is_C0) & fat_mask (~is_trail | is_C0) & fat_mask
@ -408,10 +409,16 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
is_head & shl<32>(is_Z) is_head & shl<32>(is_Z)
); );
// mask if next char is also the same char
const u32x16 is_rep
(
is_head & (shl<32>(ch) == ch)
);
// zero or one preceding space becomes prefixed to the next token // zero or one preceding space becomes prefixed to the next token
const u32x16 tok_head const u32x16 tok_head
(0 (0
| (is_head & ~leading_space) | (is_head & ~leading_space & ~is_rep)
| shr<32>(leading_space) | shr<32>(leading_space)
); );
@ -544,38 +551,50 @@ ircd::gpt::vocab::unk_tokenize(u16x16 &token,
const u8x16 str, const u8x16 str,
const u64 num) const u64 num)
{ {
const auto len
{
simd::strlen(str)
};
u64 tokens(0), consumed(0); u64 tokens(0), consumed(0);
const auto len(simd::strlen(str));
while(consumed < len && num + tokens < 16) while(consumed < len && num + tokens < 16)
{ {
u16 slen(0), tok; uint slen(0);
for(uint i(1); i < len; ++i) for(uint i(0); i < len - consumed; ++i)
{ {
u8x16 s(str); u8x16 s(str);
for(uint j(0); j < consumed; ++j) for(uint j(0); j < consumed; ++j)
s = shr<8>(s); s = shr<8>(s);
for(uint j(i); j < 16; ++j) for(uint j(i + 1); j < 16; ++j)
s[j] = 0; s[j] = 0;
u16 tok;
if((tok = find_token(s)) == u16(-1)) if((tok = find_token(s)) == u16(-1))
continue; continue;
slen = simd::strlen(s);
token[num + tokens] = tok; token[num + tokens] = tok;
slen = simd::strlen(s);
} }
//assert(slen > 0); // Last possible branch; token is bytewise identity.
consumed += slen; if(!slen)
tokens += bool(slen); token[num + tokens] = str[consumed];
consumed += std::max(slen, 1U);
tokens += 1U;
} }
assert(len >= consumed); assert(len >= consumed);
assert(num + tokens <= 16); assert(num + tokens <= 16);
const auto overflow{len - consumed}; const auto overflow{len - consumed};
assert(overflow == 0 || num + tokens == 16); assert(overflow == 0 || num + tokens == 16);
assert(consumed > 0 || tokens == 0);
assert(tokens > 0 || len == 0);
return u64x2 return u64x2
{ {
// return number of tokens created only; the caller already counted
// the length of str as consumed input.
tokens, 0 tokens, 0
}; };
} }