ircd::gpt: Remove hostside backprop branch for now; simplify sample tokenizer.

Jason Volk 2022-10-24 00:26:40 +00:00
parent d5dc477de5
commit 110d4e7b17
2 changed files with 11 additions and 206 deletions

View File

@@ -29,7 +29,6 @@ struct ircd::gpt::step
 	pipe::prof profile;
 
 	void profile_accumulate(const pipe::prof &);
-	bool backpropagate();
 
   public:
 	bool done() const noexcept;

View File

@@ -399,11 +399,6 @@ const noexcept
 //
 // epoch
 //
-namespace ircd::gpt
-{
-	static thread_local u16 marker alignas(64) [1024];
-}
-
 //
 // epoch::epoch
 //
@@ -467,12 +462,6 @@ ircd::gpt::epoch::operator()()
 	while(!step())
 		ctx::interruption_point();
 
-	if(!step.backpropagate())
-		throw error
-		{
-			"Failed to backprop."
-		};
-
 	return done();
 }
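With the call removed, epoch::operator() reduces to driving each step to completion. Reconstructed from the hunk's own context lines, the surviving control flow is simply:

bool
ircd::gpt::epoch::operator()()
{
	// Run every step in the epoch, yielding at interruption points.
	while(!step())
		ctx::interruption_point();

	return done();
}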
@@ -546,144 +535,6 @@ noexcept
 	log_debug_prof(opts, ctrl, this->profile);
 }
 
-bool
-ircd::gpt::step::backpropagate()
-{
-	const auto hit
-	{
-		ctrl.target.logit.token == ctrl.select.logit.token
-	};
-
-	const auto select_loss_mean
-	{
-		ctrl.select.loss.mean
-	};
-
-	const auto target_loss_mean
-	{
-		ctrl.target.loss.mean
-	};
-
-	const auto loss_mean
-	{
-		(target_loss_mean + select_loss_mean) / 2.0f
-	};
-
-	static float mean_best { 10000.0f }, target_mean_best { 10000.0f };
-	static ulong hit_best;
-	static bool tack, last_tack;
-
-	last_tack = tack;
-	const auto loss
-	{
-		loss_mean
-	};
-
-	const bool improve_global
-	{
-		target_loss_mean < target_mean_best
-	};
-
-	const bool improve
-	{
-		improve_global
-	};
-
-	if(improve)
-		mean_best = loss,
-		target_mean_best = target_loss_mean,
-		hit_best = ctrl.hit;
-	else
-		tack = !tack;
-
-	const auto grad
-	{
-		!tack? loss : -loss
-	};
-
-	const auto steps
-	{
-		(opts.training_steps + opts.validation_steps + opts.testing_steps) / opts.batch_size
-	};
-
-	const auto step
-	{
-		this->epoch.id * steps + this->id
-	};
-
-	log::logf
-	{
-		log, improve? log::level::INFO: log::level::ERROR,
-		"epoch:%u step:%u completed range[%u -> %zu] dsid:%u target:%-10.7f select:%-10.7f loss:%-10.7f [ %10.7f ] hit:%u miss:%u",
-		this->epoch.id,
-		step,
-		this->start,
-		this->start + opts.batch_size,
-		this->id * opts.batch_size + ctrl.clk.samp,
-		target_loss_mean,
-		select_loss_mean,
-		loss,
-		grad * opts.alpha,
-		ctrl.hit,
-		ctrl.miss,
-	};
-
-	if(!opts.alpha)
-		return true;
-
-	if(!improve)
-		return false;
-
-	cl::exec
-	{
-		desc.model->decode->master[0], std::memory_order_acq_rel
-	};
-
-	auto &model
-	{
-		*mutable_cast(desc.model->decode_const)
-	};
-
-	const mutable_buffer model_buffer
-	{
-		reinterpret_cast<char *>(&model),
-		sizeof(gpt::model::decoder) * 3
-	};
-
-	const mutable_buffer checkpoint_buffer
-	{
-		reinterpret_cast<char *>(&model) + sizeof(gpt::model::decoder) * 3,
-		sizeof(gpt::model::decoder) * 3
-	};
-
-	if(improve)
-		copy(checkpoint_buffer, model_buffer);
-	else
-		copy(model_buffer, checkpoint_buffer);
-
-	ircd::timer stopwatch;
-	backprop(opts, step, grad, model, epoch.moment);
-	allocator::sync(model_buffer);
-
-	char pbuf[1][32];
-	log::logf
-	{
-		log, improve? log::level::DEBUG: log::level::ERROR,
-		"backpropagation step:%u lr:%-8.6f mean:%-10.7f$L hits:%-5u Tbest:%-10.7f$L Mbest:%-10.7f$L Hbest:%-5lu grad:{ %10.7f$L } %s",
-		step,
-		opts.alpha,
-		loss_mean,
-		ctrl.hit,
-		target_mean_best,
-		mean_best,
-		hit_best,
-		grad,
-		pretty(pbuf[0], stopwatch.at<milliseconds>(), 1),
-	};
-
-	return true;
-}
-
 bool
 ircd::gpt::step::operator()()
 {
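For reference, the removed branch implemented a coarse host-side scheme rather than full optimizer bookkeeping: the live decoder weights are shadowed by a checkpoint copy in adjacent host memory; an improving step saves a new checkpoint, while a regressing step restores the last one and flips the gradient sign (the "tack"). Below is a minimal standalone sketch of that pattern only; it is not ircd code, and the trainer type, buffer layout, and fixed learning rate are illustrative assumptions.

// Standalone sketch of the checkpoint/restore pattern the removed
// branch used (hypothetical names; not the ircd::gpt implementation).
#include <algorithm>
#include <cstdio>
#include <vector>

struct trainer
{
	std::vector<float> live;        // working weights
	std::vector<float> checkpoint;  // last known-good weights
	float best {10000.0f};          // best loss seen so far
	bool tack {false};              // gradient sign flips on regression

	explicit trainer(const size_t n):
		live(n), checkpoint(n) {}

	// Returns true when this step improved on the best loss.
	bool step(const float loss)
	{
		const bool improve(loss < best);
		if(improve)
		{
			// Save a new checkpoint of the live weights.
			best = loss;
			std::copy(live.begin(), live.end(), checkpoint.begin());
		}
		else
		{
			// Restore the last checkpoint and change tack.
			tack = !tack;
			std::copy(checkpoint.begin(), checkpoint.end(), live.begin());
		}

		// Apply a signed scalar update (hypothetical fixed rate).
		const float grad(tack? -loss: loss);
		for(auto &w : live)
			w -= 0.001f * grad;

		return improve;
	}
};

int main()
{
	trainer t(8);
	for(const float loss : {3.0f, 2.5f, 2.7f, 2.2f})
		std::printf("improve=%d best=%.2f\n", t.step(loss), t.best);
}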
@@ -739,7 +590,7 @@ ircd::gpt::samp::samp(gpt::step &step)
 }
 ,id
 {
-	ctrl.clk.samp
+	ctrl.clk.step * opts.batch_size + ctrl.clk.samp
 }
 ,accept
 {
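The sample id now folds in the batch offset, making it a global dataset index; this is what lets tokenize() below replace step.start + ctrl.clk.samp with id alone. A small self-contained check of that arithmetic, using hypothetical values:

#include <cassert>

int main()
{
	// Hypothetical values: batch_size=32, clock at step 5, sample 7.
	const unsigned batch_size {32}, step {5}, samp {7};
	const unsigned start {step * batch_size};      // first sample of the batch
	const unsigned id {step * batch_size + samp};  // the new global sample id

	// The new id equals the old per-call computation start + samp,
	// so downstream consumers can index the dataset with id directly.
	assert(id == start + samp);
	return 0;
}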
@@ -852,7 +703,7 @@ ircd::gpt::samp::tokenize()
 {
 	const auto idx
 	{
-		step.start + ctrl.clk.samp
+		id
 	};
 
 	const gpt::model::text text
@@ -871,65 +722,20 @@ ircd::gpt::samp::tokenize()
 		json::unescape(str_buf, input)
 	};
 
 	assert(!empty(str));
-	static const auto delim
+	const vector_view<u16> buf
 	{
-		"\n\n"_sv
+		ctrl.token, opts.buffer_tokens
 	};
 
-	const int phrases
-	(
-		ircd::token_count(str, delim)
-	);
-
-	uint count(0);
-	int p(phrases);
-	assert(p >= 0);
-	if(startswith(str, delim))
+	const auto in
 	{
-		ctrl.token[count++] = 198;
-		ctrl.token[count++] = 198;
-	}
+		gpt::vocab::tokenize(buf, str)
+	};
 
-	ircd::tokens(str, delim, [this, &count, &p, &phrases]
-	(const string_view &phrase) noexcept -> bool
+	const auto count
 	{
-		assert(!empty(phrase));
-		const vector_view<u16> buf
-		{
-			ctrl.token + count, opts.buffer_tokens - count
-		};
-
-		const auto in
-		{
-			gpt::vocab::tokenize(buf, phrase)
-		};
-
-		if(count + size(in) + 2 > opts.context_tokens)
-			return false;
-
-		count += size(in);
-		ctrl.token[count++] = 198;
-		ctrl.token[count++] = 198;
-		assert(p > 0);
-		marker[--p] = count;
-		return true;
-	});
-
-	for(assert(p >= 0); p < phrases; ++p)
-		if(marker[p] <= opts.context_tokens)
-			break;
-
-	assert(p <= phrases);
-	count = marker[p];
-	for(uint i(count); i < opts.buffer_tokens; ++i)
-		ctrl.token[i] = 198;
-
-	if(!endswith(str, delim))
-		count -= 2;
+		size(in)
+	};
 
 	assert(count > 0);
 	assert(count <= opts.context_tokens);
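The simplification drops the phrase-splitting path, which split the input on "\n\n", tokenized phrase by phrase, wrote token 198 (the GPT-2 byte-pair encoding of a newline) as a delimiter, and trimmed to the context window via the thread-local marker array removed above. The new path makes a single gpt::vocab::tokenize() call over the whole unescaped string. A minimal sketch of that one-shot shape follows; the tokenizer is a stand-in, since the real vocab machinery is not reproduced here:

// One-shot tokenization into a fixed buffer, then a bounds check
// against the context window (stand-in tokenizer; not gpt::vocab).
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <span>
#include <string_view>

// Stand-in for gpt::vocab::tokenize(): writes one token per input
// character and returns the prefix of the buffer actually filled.
static std::span<uint16_t>
tokenize(std::span<uint16_t> buf, std::string_view str)
{
	size_t i(0);
	for(; i < str.size() && i < buf.size(); ++i)
		buf[i] = uint16_t(str[i]);

	return buf.first(i);
}

int main()
{
	constexpr size_t buffer_tokens {1024}, context_tokens {512};
	uint16_t token[buffer_tokens];

	const auto in
	{
		tokenize(std::span(token), "hello world")
	};

	const auto count
	{
		in.size()
	};

	assert(count > 0);
	assert(count <= context_tokens);
	return 0;
}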
@@ -1300,7 +1106,7 @@ ircd::gpt::debug_head(const mutable_buffer &out,
 	{
 		out, "%s[%4u]-%1u",
 		debug_head(head, opts, ctrl.clk),
-		ctrl.count - 1,
+		ctrl.count,
 		ctrl.dispatch,
 	};
 }