mirror of
https://github.com/matrix-construct/construct
synced 2024-11-04 21:08:57 +01:00
ircd::gpt::vocab: Fixes for additional mismatching cases.
This commit is contained in:
parent
9c062d9c3f
commit
eeadc15319
1 changed files with 28 additions and 9 deletions
|
@ -397,6 +397,7 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
|
||||||
lane_cast<u32x16>(in_mask) != 0
|
lane_cast<u32x16>(in_mask) != 0
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// mask candidate start of token
|
||||||
const u32x16 is_head
|
const u32x16 is_head
|
||||||
(
|
(
|
||||||
(~is_trail | is_C0) & fat_mask
|
(~is_trail | is_C0) & fat_mask
|
||||||
|
@ -408,10 +409,16 @@ ircd::gpt::vocab::pre_tokenize(u8x16 (&token)[16],
|
||||||
is_head & shl<32>(is_Z)
|
is_head & shl<32>(is_Z)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// mask if next char is also the same char
|
||||||
|
const u32x16 is_rep
|
||||||
|
(
|
||||||
|
is_head & (shl<32>(ch) == ch)
|
||||||
|
);
|
||||||
|
|
||||||
// zero or one preceding space becomes prefixed to the next token
|
// zero or one preceding space becomes prefixed to the next token
|
||||||
const u32x16 tok_head
|
const u32x16 tok_head
|
||||||
(0
|
(0
|
||||||
| (is_head & ~leading_space)
|
| (is_head & ~leading_space & ~is_rep)
|
||||||
| shr<32>(leading_space)
|
| shr<32>(leading_space)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -544,38 +551,50 @@ ircd::gpt::vocab::unk_tokenize(u16x16 &token,
|
||||||
const u8x16 str,
|
const u8x16 str,
|
||||||
const u64 num)
|
const u64 num)
|
||||||
{
|
{
|
||||||
|
const auto len
|
||||||
|
{
|
||||||
|
simd::strlen(str)
|
||||||
|
};
|
||||||
|
|
||||||
u64 tokens(0), consumed(0);
|
u64 tokens(0), consumed(0);
|
||||||
const auto len(simd::strlen(str));
|
|
||||||
while(consumed < len && num + tokens < 16)
|
while(consumed < len && num + tokens < 16)
|
||||||
{
|
{
|
||||||
u16 slen(0), tok;
|
uint slen(0);
|
||||||
for(uint i(1); i < len; ++i)
|
for(uint i(0); i < len - consumed; ++i)
|
||||||
{
|
{
|
||||||
u8x16 s(str);
|
u8x16 s(str);
|
||||||
for(uint j(0); j < consumed; ++j)
|
for(uint j(0); j < consumed; ++j)
|
||||||
s = shr<8>(s);
|
s = shr<8>(s);
|
||||||
|
|
||||||
for(uint j(i); j < 16; ++j)
|
for(uint j(i + 1); j < 16; ++j)
|
||||||
s[j] = 0;
|
s[j] = 0;
|
||||||
|
|
||||||
|
u16 tok;
|
||||||
if((tok = find_token(s)) == u16(-1))
|
if((tok = find_token(s)) == u16(-1))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
slen = simd::strlen(s);
|
|
||||||
token[num + tokens] = tok;
|
token[num + tokens] = tok;
|
||||||
|
slen = simd::strlen(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
//assert(slen > 0);
|
// Last possible branch; token is bytewise identity.
|
||||||
consumed += slen;
|
if(!slen)
|
||||||
tokens += bool(slen);
|
token[num + tokens] = str[consumed];
|
||||||
|
|
||||||
|
consumed += std::max(slen, 1U);
|
||||||
|
tokens += 1U;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(len >= consumed);
|
assert(len >= consumed);
|
||||||
assert(num + tokens <= 16);
|
assert(num + tokens <= 16);
|
||||||
const auto overflow{len - consumed};
|
const auto overflow{len - consumed};
|
||||||
assert(overflow == 0 || num + tokens == 16);
|
assert(overflow == 0 || num + tokens == 16);
|
||||||
|
assert(consumed > 0 || tokens == 0);
|
||||||
|
assert(tokens > 0 || len == 0);
|
||||||
return u64x2
|
return u64x2
|
||||||
{
|
{
|
||||||
|
// return number of tokens created only; the caller already counted
|
||||||
|
// the length of str as consumed input.
|
||||||
tokens, 0
|
tokens, 0
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue