mirror of
https://github.com/matrix-construct/construct
synced 2025-02-18 09:40:12 +01:00
ircd::gpt::vocab: Fixes for additional mismatching cases.
This commit is contained in:
parent
9c062d9c3f
commit
eeadc15319
1 changed files with 28 additions and 9 deletions
|
@ -397,6 +397,7 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
|
|||
lane_cast<u32x16>(in_mask) != 0
|
||||
);
|
||||
|
||||
// mask candidate start of token
|
||||
const u32x16 is_head
|
||||
(
|
||||
(~is_trail | is_C0) & fat_mask
|
||||
|
@ -408,10 +409,16 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
|
|||
is_head & shl<32>(is_Z)
|
||||
);
|
||||
|
||||
// mask if next char is also the same char
|
||||
const u32x16 is_rep
|
||||
(
|
||||
is_head & (shl<32>(ch) == ch)
|
||||
);
|
||||
|
||||
// zero or one preceding space becomes prefixed to the next token
|
||||
const u32x16 tok_head
|
||||
(0
|
||||
| (is_head & ~leading_space)
|
||||
| (is_head & ~leading_space & ~is_rep)
|
||||
| shr<32>(leading_space)
|
||||
);
|
||||
|
||||
|
@ -544,38 +551,50 @@ ircd::gpt::vocab::unk_tokenize(u16x16 &token,
|
|||
const u8x16 str,
|
||||
const u64 num)
|
||||
{
|
||||
const auto len
|
||||
{
|
||||
simd::strlen(str)
|
||||
};
|
||||
|
||||
u64 tokens(0), consumed(0);
|
||||
const auto len(simd::strlen(str));
|
||||
while(consumed < len && num + tokens < 16)
|
||||
{
|
||||
u16 slen(0), tok;
|
||||
for(uint i(1); i < len; ++i)
|
||||
uint slen(0);
|
||||
for(uint i(0); i < len - consumed; ++i)
|
||||
{
|
||||
u8x16 s(str);
|
||||
for(uint j(0); j < consumed; ++j)
|
||||
s = shr<8>(s);
|
||||
|
||||
for(uint j(i); j < 16; ++j)
|
||||
for(uint j(i + 1); j < 16; ++j)
|
||||
s[j] = 0;
|
||||
|
||||
u16 tok;
|
||||
if((tok = find_token(s)) == u16(-1))
|
||||
continue;
|
||||
|
||||
slen = simd::strlen(s);
|
||||
token[num + tokens] = tok;
|
||||
slen = simd::strlen(s);
|
||||
}
|
||||
|
||||
//assert(slen > 0);
|
||||
consumed += slen;
|
||||
tokens += bool(slen);
|
||||
// Last possible branch; token is bytewise identity.
|
||||
if(!slen)
|
||||
token[num + tokens] = str[consumed];
|
||||
|
||||
consumed += std::max(slen, 1U);
|
||||
tokens += 1U;
|
||||
}
|
||||
|
||||
assert(len >= consumed);
|
||||
assert(num + tokens <= 16);
|
||||
const auto overflow{len - consumed};
|
||||
assert(overflow == 0 || num + tokens == 16);
|
||||
assert(consumed > 0 || tokens == 0);
|
||||
assert(tokens > 0 || len == 0);
|
||||
return u64x2
|
||||
{
|
||||
// return number of tokens created only; the caller already counted
|
||||
// the length of str as consumed input.
|
||||
tokens, 0
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue