ircd::gpt: Remove hostside backprop branch for now; simplify sample tokenizer.
This commit is contained in:
parent
d5dc477de5
commit
110d4e7b17
|
@ -29,7 +29,6 @@ struct ircd::gpt::step
|
|||
pipe::prof profile;
|
||||
|
||||
void profile_accumulate(const pipe::prof &);
|
||||
bool backpropagate();
|
||||
|
||||
public:
|
||||
bool done() const noexcept;
|
||||
|
|
216
ircd/gpt.cc
216
ircd/gpt.cc
|
@ -399,11 +399,6 @@ const noexcept
|
|||
// epoch
|
||||
//
|
||||
|
||||
namespace ircd::gpt
|
||||
{
|
||||
static thread_local u16 marker alignas(64) [1024];
|
||||
}
|
||||
|
||||
//
|
||||
// epoch::epoch
|
||||
//
|
||||
|
@ -467,12 +462,6 @@ ircd::gpt::epoch::operator()()
|
|||
while(!step())
|
||||
ctx::interruption_point();
|
||||
|
||||
if(!step.backpropagate())
|
||||
throw error
|
||||
{
|
||||
"Failed to backprop."
|
||||
};
|
||||
|
||||
return done();
|
||||
}
|
||||
|
||||
|
@ -546,144 +535,6 @@ noexcept
|
|||
log_debug_prof(opts, ctrl, this->profile);
|
||||
}
|
||||
|
||||
bool
|
||||
ircd::gpt::step::backpropagate()
|
||||
{
|
||||
const auto hit
|
||||
{
|
||||
ctrl.target.logit.token == ctrl.select.logit.token
|
||||
};
|
||||
|
||||
const auto select_loss_mean
|
||||
{
|
||||
ctrl.select.loss.mean
|
||||
};
|
||||
|
||||
const auto target_loss_mean
|
||||
{
|
||||
ctrl.target.loss.mean
|
||||
};
|
||||
|
||||
const auto loss_mean
|
||||
{
|
||||
(target_loss_mean + select_loss_mean) / 2.0f
|
||||
};
|
||||
|
||||
static float mean_best { 10000.0f }, target_mean_best { 10000.0f };
|
||||
static ulong hit_best;
|
||||
static bool tack, last_tack;
|
||||
last_tack = tack;
|
||||
|
||||
const auto loss
|
||||
{
|
||||
loss_mean
|
||||
};
|
||||
|
||||
const bool improve_global
|
||||
{
|
||||
target_loss_mean < target_mean_best
|
||||
};
|
||||
|
||||
const bool improve
|
||||
{
|
||||
improve_global
|
||||
};
|
||||
|
||||
if(improve)
|
||||
mean_best = loss,
|
||||
target_mean_best = target_loss_mean,
|
||||
hit_best = ctrl.hit;
|
||||
else
|
||||
tack = !tack;
|
||||
|
||||
const auto grad
|
||||
{
|
||||
!tack? loss : -loss
|
||||
};
|
||||
|
||||
const auto steps
|
||||
{
|
||||
(opts.training_steps + opts.validation_steps + opts.testing_steps) / opts.batch_size
|
||||
};
|
||||
|
||||
const auto step
|
||||
{
|
||||
this->epoch.id * steps + this->id
|
||||
};
|
||||
|
||||
log::logf
|
||||
{
|
||||
log, improve? log::level::INFO: log::level::ERROR,
|
||||
"epoch:%u step:%u completed range[%u -> %zu] dsid:%u target:%-10.7f select:%-10.7f loss:%-10.7f [ %10.7f ] hit:%u miss:%u",
|
||||
this->epoch.id,
|
||||
step,
|
||||
this->start,
|
||||
this->start + opts.batch_size,
|
||||
this->id * opts.batch_size + ctrl.clk.samp,
|
||||
target_loss_mean,
|
||||
select_loss_mean,
|
||||
loss,
|
||||
grad * opts.alpha,
|
||||
ctrl.hit,
|
||||
ctrl.miss,
|
||||
};
|
||||
|
||||
if(!opts.alpha)
|
||||
return true;
|
||||
|
||||
if(!improve)
|
||||
return false;
|
||||
|
||||
cl::exec
|
||||
{
|
||||
desc.model->decode->master[0], std::memory_order_acq_rel
|
||||
};
|
||||
|
||||
auto &model
|
||||
{
|
||||
*mutable_cast(desc.model->decode_const)
|
||||
};
|
||||
|
||||
const mutable_buffer model_buffer
|
||||
{
|
||||
reinterpret_cast<char *>(&model),
|
||||
sizeof(gpt::model::decoder) * 3
|
||||
};
|
||||
|
||||
const mutable_buffer checkpoint_buffer
|
||||
{
|
||||
reinterpret_cast<char *>(&model) + sizeof(gpt::model::decoder) * 3,
|
||||
sizeof(gpt::model::decoder) * 3
|
||||
};
|
||||
|
||||
if(improve)
|
||||
copy(checkpoint_buffer, model_buffer);
|
||||
else
|
||||
copy(model_buffer, checkpoint_buffer);
|
||||
|
||||
ircd::timer stopwatch;
|
||||
backprop(opts, step, grad, model, epoch.moment);
|
||||
allocator::sync(model_buffer);
|
||||
|
||||
char pbuf[1][32];
|
||||
log::logf
|
||||
{
|
||||
log, improve? log::level::DEBUG: log::level::ERROR,
|
||||
"backpropagation step:%u lr:%-8.6f mean:%-10.7f$L hits:%-5u Tbest:%-10.7f$L Mbest:%-10.7f$L Hbest:%-5lu grad:{ %10.7f$L } %s",
|
||||
step,
|
||||
opts.alpha,
|
||||
loss_mean,
|
||||
ctrl.hit,
|
||||
target_mean_best,
|
||||
mean_best,
|
||||
hit_best,
|
||||
grad,
|
||||
pretty(pbuf[0], stopwatch.at<milliseconds>(), 1),
|
||||
};
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
ircd::gpt::step::operator()()
|
||||
{
|
||||
|
@ -739,7 +590,7 @@ ircd::gpt::samp::samp(gpt::step &step)
|
|||
}
|
||||
,id
|
||||
{
|
||||
ctrl.clk.samp
|
||||
ctrl.clk.step * opts.batch_size + ctrl.clk.samp
|
||||
}
|
||||
,accept
|
||||
{
|
||||
|
@ -852,7 +703,7 @@ ircd::gpt::samp::tokenize()
|
|||
{
|
||||
const auto idx
|
||||
{
|
||||
step.start + ctrl.clk.samp
|
||||
id
|
||||
};
|
||||
|
||||
const gpt::model::text text
|
||||
|
@ -871,65 +722,20 @@ ircd::gpt::samp::tokenize()
|
|||
json::unescape(str_buf, input)
|
||||
};
|
||||
|
||||
assert(!empty(str));
|
||||
static const auto delim
|
||||
const vector_view<u16> buf
|
||||
{
|
||||
"\n\n"_sv
|
||||
ctrl.token, opts.buffer_tokens
|
||||
};
|
||||
|
||||
const int phrases
|
||||
(
|
||||
ircd::token_count(str, delim)
|
||||
);
|
||||
|
||||
uint count(0);
|
||||
int p(phrases);
|
||||
assert(p >= 0);
|
||||
|
||||
if(startswith(str, delim))
|
||||
const auto in
|
||||
{
|
||||
ctrl.token[count++] = 198;
|
||||
ctrl.token[count++] = 198;
|
||||
}
|
||||
gpt::vocab::tokenize(buf, str)
|
||||
};
|
||||
|
||||
ircd::tokens(str, delim, [this, &count, &p, &phrases]
|
||||
(const string_view &phrase) noexcept -> bool
|
||||
const auto count
|
||||
{
|
||||
assert(!empty(phrase));
|
||||
const vector_view<u16> buf
|
||||
{
|
||||
ctrl.token + count, opts.buffer_tokens - count
|
||||
};
|
||||
|
||||
const auto in
|
||||
{
|
||||
gpt::vocab::tokenize(buf, phrase)
|
||||
};
|
||||
|
||||
if(count + size(in) + 2 > opts.context_tokens)
|
||||
return false;
|
||||
|
||||
count += size(in);
|
||||
ctrl.token[count++] = 198;
|
||||
ctrl.token[count++] = 198;
|
||||
|
||||
assert(p > 0);
|
||||
marker[--p] = count;
|
||||
return true;
|
||||
});
|
||||
|
||||
for(assert(p >= 0); p < phrases; ++p)
|
||||
if(marker[p] <= opts.context_tokens)
|
||||
break;
|
||||
|
||||
assert(p <= phrases);
|
||||
count = marker[p];
|
||||
|
||||
for(uint i(count); i < opts.buffer_tokens; ++i)
|
||||
ctrl.token[i] = 198;
|
||||
|
||||
if(!endswith(str, delim))
|
||||
count -= 2;
|
||||
size(in)
|
||||
};
|
||||
|
||||
assert(count > 0);
|
||||
assert(count <= opts.context_tokens);
|
||||
|
@ -1300,7 +1106,7 @@ ircd::gpt::debug_head(const mutable_buffer &out,
|
|||
{
|
||||
out, "%s[%4u]-%1u",
|
||||
debug_head(head, opts, ctrl.clk),
|
||||
ctrl.count - 1,
|
||||
ctrl.count,
|
||||
ctrl.dispatch,
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue