mirror of
https://github.com/matrix-construct/construct
synced 2024-05-29 16:23:45 +02:00
ircd::gpt: Remove hostside backprop branch for now; simplify sample tokenizer.
This commit is contained in:
parent
d5dc477de5
commit
110d4e7b17
|
@ -29,7 +29,6 @@ struct ircd::gpt::step
|
||||||
pipe::prof profile;
|
pipe::prof profile;
|
||||||
|
|
||||||
void profile_accumulate(const pipe::prof &);
|
void profile_accumulate(const pipe::prof &);
|
||||||
bool backpropagate();
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
bool done() const noexcept;
|
bool done() const noexcept;
|
||||||
|
|
216
ircd/gpt.cc
216
ircd/gpt.cc
|
@ -399,11 +399,6 @@ const noexcept
|
||||||
// epoch
|
// epoch
|
||||||
//
|
//
|
||||||
|
|
||||||
namespace ircd::gpt
|
|
||||||
{
|
|
||||||
static thread_local u16 marker alignas(64) [1024];
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// epoch::epoch
|
// epoch::epoch
|
||||||
//
|
//
|
||||||
|
@ -467,12 +462,6 @@ ircd::gpt::epoch::operator()()
|
||||||
while(!step())
|
while(!step())
|
||||||
ctx::interruption_point();
|
ctx::interruption_point();
|
||||||
|
|
||||||
if(!step.backpropagate())
|
|
||||||
throw error
|
|
||||||
{
|
|
||||||
"Failed to backprop."
|
|
||||||
};
|
|
||||||
|
|
||||||
return done();
|
return done();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -546,144 +535,6 @@ noexcept
|
||||||
log_debug_prof(opts, ctrl, this->profile);
|
log_debug_prof(opts, ctrl, this->profile);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool
|
|
||||||
ircd::gpt::step::backpropagate()
|
|
||||||
{
|
|
||||||
const auto hit
|
|
||||||
{
|
|
||||||
ctrl.target.logit.token == ctrl.select.logit.token
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto select_loss_mean
|
|
||||||
{
|
|
||||||
ctrl.select.loss.mean
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto target_loss_mean
|
|
||||||
{
|
|
||||||
ctrl.target.loss.mean
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto loss_mean
|
|
||||||
{
|
|
||||||
(target_loss_mean + select_loss_mean) / 2.0f
|
|
||||||
};
|
|
||||||
|
|
||||||
static float mean_best { 10000.0f }, target_mean_best { 10000.0f };
|
|
||||||
static ulong hit_best;
|
|
||||||
static bool tack, last_tack;
|
|
||||||
last_tack = tack;
|
|
||||||
|
|
||||||
const auto loss
|
|
||||||
{
|
|
||||||
loss_mean
|
|
||||||
};
|
|
||||||
|
|
||||||
const bool improve_global
|
|
||||||
{
|
|
||||||
target_loss_mean < target_mean_best
|
|
||||||
};
|
|
||||||
|
|
||||||
const bool improve
|
|
||||||
{
|
|
||||||
improve_global
|
|
||||||
};
|
|
||||||
|
|
||||||
if(improve)
|
|
||||||
mean_best = loss,
|
|
||||||
target_mean_best = target_loss_mean,
|
|
||||||
hit_best = ctrl.hit;
|
|
||||||
else
|
|
||||||
tack = !tack;
|
|
||||||
|
|
||||||
const auto grad
|
|
||||||
{
|
|
||||||
!tack? loss : -loss
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto steps
|
|
||||||
{
|
|
||||||
(opts.training_steps + opts.validation_steps + opts.testing_steps) / opts.batch_size
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto step
|
|
||||||
{
|
|
||||||
this->epoch.id * steps + this->id
|
|
||||||
};
|
|
||||||
|
|
||||||
log::logf
|
|
||||||
{
|
|
||||||
log, improve? log::level::INFO: log::level::ERROR,
|
|
||||||
"epoch:%u step:%u completed range[%u -> %zu] dsid:%u target:%-10.7f select:%-10.7f loss:%-10.7f [ %10.7f ] hit:%u miss:%u",
|
|
||||||
this->epoch.id,
|
|
||||||
step,
|
|
||||||
this->start,
|
|
||||||
this->start + opts.batch_size,
|
|
||||||
this->id * opts.batch_size + ctrl.clk.samp,
|
|
||||||
target_loss_mean,
|
|
||||||
select_loss_mean,
|
|
||||||
loss,
|
|
||||||
grad * opts.alpha,
|
|
||||||
ctrl.hit,
|
|
||||||
ctrl.miss,
|
|
||||||
};
|
|
||||||
|
|
||||||
if(!opts.alpha)
|
|
||||||
return true;
|
|
||||||
|
|
||||||
if(!improve)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
cl::exec
|
|
||||||
{
|
|
||||||
desc.model->decode->master[0], std::memory_order_acq_rel
|
|
||||||
};
|
|
||||||
|
|
||||||
auto &model
|
|
||||||
{
|
|
||||||
*mutable_cast(desc.model->decode_const)
|
|
||||||
};
|
|
||||||
|
|
||||||
const mutable_buffer model_buffer
|
|
||||||
{
|
|
||||||
reinterpret_cast<char *>(&model),
|
|
||||||
sizeof(gpt::model::decoder) * 3
|
|
||||||
};
|
|
||||||
|
|
||||||
const mutable_buffer checkpoint_buffer
|
|
||||||
{
|
|
||||||
reinterpret_cast<char *>(&model) + sizeof(gpt::model::decoder) * 3,
|
|
||||||
sizeof(gpt::model::decoder) * 3
|
|
||||||
};
|
|
||||||
|
|
||||||
if(improve)
|
|
||||||
copy(checkpoint_buffer, model_buffer);
|
|
||||||
else
|
|
||||||
copy(model_buffer, checkpoint_buffer);
|
|
||||||
|
|
||||||
ircd::timer stopwatch;
|
|
||||||
backprop(opts, step, grad, model, epoch.moment);
|
|
||||||
allocator::sync(model_buffer);
|
|
||||||
|
|
||||||
char pbuf[1][32];
|
|
||||||
log::logf
|
|
||||||
{
|
|
||||||
log, improve? log::level::DEBUG: log::level::ERROR,
|
|
||||||
"backpropagation step:%u lr:%-8.6f mean:%-10.7f$L hits:%-5u Tbest:%-10.7f$L Mbest:%-10.7f$L Hbest:%-5lu grad:{ %10.7f$L } %s",
|
|
||||||
step,
|
|
||||||
opts.alpha,
|
|
||||||
loss_mean,
|
|
||||||
ctrl.hit,
|
|
||||||
target_mean_best,
|
|
||||||
mean_best,
|
|
||||||
hit_best,
|
|
||||||
grad,
|
|
||||||
pretty(pbuf[0], stopwatch.at<milliseconds>(), 1),
|
|
||||||
};
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
ircd::gpt::step::operator()()
|
ircd::gpt::step::operator()()
|
||||||
{
|
{
|
||||||
|
@ -739,7 +590,7 @@ ircd::gpt::samp::samp(gpt::step &step)
|
||||||
}
|
}
|
||||||
,id
|
,id
|
||||||
{
|
{
|
||||||
ctrl.clk.samp
|
ctrl.clk.step * opts.batch_size + ctrl.clk.samp
|
||||||
}
|
}
|
||||||
,accept
|
,accept
|
||||||
{
|
{
|
||||||
|
@ -852,7 +703,7 @@ ircd::gpt::samp::tokenize()
|
||||||
{
|
{
|
||||||
const auto idx
|
const auto idx
|
||||||
{
|
{
|
||||||
step.start + ctrl.clk.samp
|
id
|
||||||
};
|
};
|
||||||
|
|
||||||
const gpt::model::text text
|
const gpt::model::text text
|
||||||
|
@ -871,65 +722,20 @@ ircd::gpt::samp::tokenize()
|
||||||
json::unescape(str_buf, input)
|
json::unescape(str_buf, input)
|
||||||
};
|
};
|
||||||
|
|
||||||
assert(!empty(str));
|
const vector_view<u16> buf
|
||||||
static const auto delim
|
|
||||||
{
|
{
|
||||||
"\n\n"_sv
|
ctrl.token, opts.buffer_tokens
|
||||||
};
|
};
|
||||||
|
|
||||||
const int phrases
|
const auto in
|
||||||
(
|
|
||||||
ircd::token_count(str, delim)
|
|
||||||
);
|
|
||||||
|
|
||||||
uint count(0);
|
|
||||||
int p(phrases);
|
|
||||||
assert(p >= 0);
|
|
||||||
|
|
||||||
if(startswith(str, delim))
|
|
||||||
{
|
{
|
||||||
ctrl.token[count++] = 198;
|
gpt::vocab::tokenize(buf, str)
|
||||||
ctrl.token[count++] = 198;
|
};
|
||||||
}
|
|
||||||
|
|
||||||
ircd::tokens(str, delim, [this, &count, &p, &phrases]
|
const auto count
|
||||||
(const string_view &phrase) noexcept -> bool
|
|
||||||
{
|
{
|
||||||
assert(!empty(phrase));
|
size(in)
|
||||||
const vector_view<u16> buf
|
};
|
||||||
{
|
|
||||||
ctrl.token + count, opts.buffer_tokens - count
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto in
|
|
||||||
{
|
|
||||||
gpt::vocab::tokenize(buf, phrase)
|
|
||||||
};
|
|
||||||
|
|
||||||
if(count + size(in) + 2 > opts.context_tokens)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
count += size(in);
|
|
||||||
ctrl.token[count++] = 198;
|
|
||||||
ctrl.token[count++] = 198;
|
|
||||||
|
|
||||||
assert(p > 0);
|
|
||||||
marker[--p] = count;
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
|
|
||||||
for(assert(p >= 0); p < phrases; ++p)
|
|
||||||
if(marker[p] <= opts.context_tokens)
|
|
||||||
break;
|
|
||||||
|
|
||||||
assert(p <= phrases);
|
|
||||||
count = marker[p];
|
|
||||||
|
|
||||||
for(uint i(count); i < opts.buffer_tokens; ++i)
|
|
||||||
ctrl.token[i] = 198;
|
|
||||||
|
|
||||||
if(!endswith(str, delim))
|
|
||||||
count -= 2;
|
|
||||||
|
|
||||||
assert(count > 0);
|
assert(count > 0);
|
||||||
assert(count <= opts.context_tokens);
|
assert(count <= opts.context_tokens);
|
||||||
|
@ -1300,7 +1106,7 @@ ircd::gpt::debug_head(const mutable_buffer &out,
|
||||||
{
|
{
|
||||||
out, "%s[%4u]-%1u",
|
out, "%s[%4u]-%1u",
|
||||||
debug_head(head, opts, ctrl.clk),
|
debug_head(head, opts, ctrl.clk),
|
||||||
ctrl.count - 1,
|
ctrl.count,
|
||||||
ctrl.dispatch,
|
ctrl.dispatch,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue