// Matrix Construct Is All You Need Is All You Need Is AllĊĊĊĊĊĊĊĊ
//
// Copyright (C) Matrix Construct Developers, Authors & Contributors
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice is present in all copies. The
// full license for this software is available in the LICENSE file.

decltype(ircd::gpt::log)
ircd::gpt::log
{
	"gpt"
};

ircd::string_view
ircd::gpt::generate(const mutable_buffer &out,
                    const string_view &in,
                    task &task)
{
	u16 input_buf[1024];
	const auto input_tokens
	{
		gpt::vocab::tokenize(input_buf, in)
	};

	u16 output_buf[1024];
	const auto output_tokens
	{
		generate(output_buf, input_tokens, task)
	};

	const auto output
	{
		gpt::vocab::detokenize(out, output_tokens)
	};

	return output;
}

ircd::vector_view<ircd::u16>
ircd::gpt::generate(const vector_view<u16> &out,
                    const vector_view<const u16> &in,
                    task &task)
{
	assert(task.opts);
	const auto &opts
	{
		*task.opts
	};

	assert(task.ctrl);
	auto &ctrl
	{
		*task.ctrl
	};

	size_t in_i(0);
	while(in_i < in.size() && ctrl.count < opts.buffer_tokens)
		if(in[in_i] == 628)
		{
			ctrl.token[ctrl.count++] = 198;
			ctrl.token[ctrl.count++] = 198;
			in_i++;
		}
		else ctrl.token[ctrl.count++] = in[in_i++];

	generate(task);

	size_t out_i(0);
	for(; out_i < out.size() && in_i + out_i < ctrl.count; out_i++)
		out[out_i] = ctrl.token[in_i + out_i];

	return vector_view<u16>
	{
		out, out_i
	};
}

void
ircd::gpt::generate(task &task)
{
	assert(task.opts);
	const auto &opts
	{
		*task.opts
	};

	assert(task.ctrl);
	auto &ctrl
	{
		*task.ctrl
	};

	gpt::epoch epoch
	{
		task
	};

	gpt::step step
	{
		epoch
	};

	gpt::samp samp
	{
		step
	};

	bool halt {false}; do
	{
		gpt::pipe::cycle cycle
		{
			samp
		};

		halt = !samp.evaluate(cycle);
	}
	while(!halt);
}

//
// debug
//

void
ircd::gpt::log_debug_prof(const opts &opts,
                          const ctrl &ctrl,
                          const pipe::prof &prof)
{
	static char
	buf[2][512];

	const auto head
	{
		debug_head(buf[0], opts, ctrl)
	};

	for(uint i(0); i < prof.stages; ++i)
	{
		if(!std::get<1>(prof.info[i]))
			continue;

		log::logf
		{
			log, log::level::DEBUG,
			"%s %2u: %s",
			head,
			i,
			pipe::debug(buf[1], prof, i),
		};
	}
}

void
ircd::gpt::log_debug_topn(const opts &opts,
                          const ctrl &ctrl)
{
	static char
	buf[2][512];

	const auto head
	{
		debug_head(buf[0], opts, ctrl)
	};

	for(uint i(0); i < opts.top_n; ++i)
		log::logf
		{
			log, log::level::DEBUG,
			"%s %s",
			head,
			debug_top(buf[1], opts, ctrl, i),
		};
}

void
ircd::gpt::log_debug_labels(const opts &opts,
                            const ctrl &ctrl)
{
	static char
	buf[2][512];

	const auto head
	{
		debug_head(buf[0], opts, ctrl)
	};

	for(uint i(0); i < opts.labels; ++i)
		log::logf
		{
			log, log::level::DEBUG,
			"%s %s",
			head,
			debug_label(buf[1], opts, ctrl, i, 1),
		};
}

void
ircd::gpt::log_debug_attns_top(const opts &opts,
                               const ctrl &ctrl)
{
	static char
	buf[8][512];

	const auto head
	{
		debug_head(buf[0], opts, ctrl)
	};

	std::map<uint, uint> tokm;
	for(uint i(0); i < opts.layers; ++i)
		for(uint j(0); j < opts.attn_rank; ++j)
			tokm[ctrl.attn[i][j]]++;

	std::vector<std::pair<uint, uint>> tok(begin(tokm), end(tokm));
	std::sort(begin(tok), end(tok), [&tokm]
	(const auto &a, const auto &b)
	{
		return b.second < a.second;
	});

	for(const auto &[idx, score] : tok)
	{
		const auto barsz
		{
			std::min(score, std::min(80U, uint(sizeof(buf[2]) - 1)))
		};

		memset(buf[2], '|', barsz);
		buf[2][barsz] = '\0';

		log::logf
		{
			log, log::level::DEBUG,
			"%s %s [%3u] %s %-3u",
			head,
			vocab::debug(buf[1], ctrl.token[idx], 1),
			idx,
			buf[2],
			score,
		};
	}
}

void
ircd::gpt::log_debug_attns(const opts &opts,
                           const ctrl &ctrl)
{
	static char
	buf[2][512];

	const auto head
	{
		debug_head(buf[0], opts, ctrl)
	};

	for(uint i(0); i < ctrl.count; ++i)
		log::logf
		{
			log, log::level::DEBUG,
			"%s %s",
			head,
			debug_attn(buf[1], opts, ctrl, i),
		};
}

void
ircd::gpt::log_debug_token(const opts &opts,
                           const ctrl &ctrl,
                           const uint i)
{
	static char
	buf[2][512];

	log::logf
	{
		log, log::level::DEBUG,
		"%s %s",
		debug_head(buf[0], opts, ctrl),
		debug_token_at(buf[1], opts, ctrl, i),
	};
}

void
ircd::gpt::log_debug(const opts &opts,
                     const ctrl &ctrl)
{
	static char
	buf[2][512];

	log::logf
	{
		log, log::level::DEBUG,
		"%s %s",
		debug_head(buf[0], opts, ctrl),
		debug(buf[1], opts, ctrl),
	};
}

///////////////////////////////////////////////////////////////////////////////
//
// gpt::task
//

void
ircd::gpt::reset(task &task)
noexcept
{
	clear(task);
	seed(task);
}

void
ircd::gpt::clear(task &task)
noexcept
{
	assert(task.ctrl);
	memset(task.ctrl, 0x0, sizeof(gpt::ctrl));
}

void
ircd::gpt::seed(task &task)
noexcept
{
	assert(task.opts);
	seed(task, task.opts->seed);
}

void
ircd::gpt::seed(task &task,
                const uint64_t &val)
noexcept
{
	assert(task.ctrl);
    task.ctrl->rand[0] = val;
    task.ctrl->rand[1] = val;
    task.ctrl->rand[2] = 65537;
    task.ctrl->rand[3] = -1UL;
}

//
// gpt::task::task
//

ircd::gpt::task::task(const gpt::opts *const opts,
                      gpt::ctrl *const ctrl)
try
:opts
{
	opts
}
,ctrl
{
	ctrl
}
,code
{
	std::make_unique<pipe::code>()
}
,model
{
	std::make_unique<pipe::model>
	(
		*const_cast<const gpt::model::decoder *>(gpt::model::default_model)
	)
}
,desc
{
	this->opts,
	this->ctrl,
	*this->model,
	*this->code,
}
{
	assert(aligned(opts, size_t(cl::data::gart_page_size)));
	assert(aligned(ctrl, size_t(cl::data::gart_page_size)));

	seed(*this, this->opts->seed);
}
catch(const std::exception &e)
{
	log::error
	{
		log, "Task ctor :%s", e.what()
	};

	throw;
}

ircd::gpt::task::~task()
noexcept
{
}

bool
ircd::gpt::task::operator()()
{
	gpt::epoch epoch
	{
		*this
	};

	while(!epoch())
		ctx::interruption_point();

	return done();
}

bool
ircd::gpt::task::done()
const noexcept
{
	return false;
}

///////////////////////////////////////////////////////////////////////////////
//
// epoch
//

namespace ircd::gpt
{
	static thread_local u16 marker alignas(64) [1024];
}

//
// epoch::epoch
//

ircd::gpt::epoch::epoch(gpt::task &task)
:task
{
	task
}
,desc
{
	task.desc
}
,opts
{
	*task.opts
}
,ctrl
{
	*task.ctrl
}
,id
{
	ctrl.clk.epoch
}
,start
{
	0
}
,stop
{
	std::min(start + uint(opts.batch_size), gpt::model::default_data.size())
}
,moment
{
	gpt::model::default_moment[0],
	gpt::model::default_moment[1],
}
{
	assert(task.opts);
	assert(task.ctrl);

	ctrl.clk.step = 0;
}

ircd::gpt::epoch::~epoch()
noexcept
{
	if(opts.debug & 0x80000000U)
		log_debug_prof(opts, ctrl, this->profile);
}

bool
ircd::gpt::epoch::operator()()
{
	gpt::step step
	{
		*this
	};

	while(!step())
		ctx::interruption_point();

	if(!step.backpropagate())
		throw error
		{
			"Failed to backprop."
		};

	return done();
}

bool
ircd::gpt::epoch::done()
const noexcept
{
	return ctrl.clk.epoch != id;
}

void
ircd::gpt::epoch::profile_accumulate(const pipe::prof &profile)
{
	for(size_t i(0); i < profile.ts.size(); ++i)
		for(size_t j(0); j < profile.phases; ++j)
			this->profile.ts[i][j] += profile.ts[i][j];
}

///////////////////////////////////////////////////////////////////////////////
//
// step::step
//

ircd::gpt::step::step(gpt::epoch &epoch)
:epoch
{
	epoch
}
,desc
{
	epoch.desc
}
,opts
{
	epoch.opts
}
,ctrl
{
	epoch.ctrl
}
,id
{
	ctrl.clk.step
}
,start
{
	ctrl.clk.step * opts.batch_size
}
{
	assert(opts.batch_size > 0);

	ctrl.clk.samp = 0;
	ctrl.hit = 0;
	ctrl.miss = 0;
	ctrl.target.ppl = {{0}};
	ctrl.target.loss = {{0}};
	ctrl.select.ppl = {{0}};
	ctrl.select.loss = {{0}};

	for(uint i(0); i < opts.labels; ++i)
	{
		ctrl.label[i].ppl = {{0}};
		ctrl.label[i].loss = {{0}};
	}
}

ircd::gpt::step::~step()
noexcept
{
	if(opts.debug & 0x40000000U)
		log_debug_prof(opts, ctrl, this->profile);
}

bool
ircd::gpt::step::backpropagate()
{
	const auto hit
	{
		ctrl.target.logit.token == ctrl.select.logit.token
	};

	const auto select_loss_mean
	{
		ctrl.select.loss.mean
	};

	const auto target_loss_mean
	{
		ctrl.target.loss.mean
	};

	const auto loss_mean
	{
		(target_loss_mean + select_loss_mean) / 2.0f
	};

	static float mean_best { 10000.0f }, target_mean_best { 10000.0f };
	static ulong hit_best;
	static bool tack, last_tack;
	last_tack = tack;

	const auto loss
	{
		loss_mean
	};

	const bool improve_global
	{
		target_loss_mean < target_mean_best
	};

	const bool improve
	{
		improve_global
	};

	if(improve)
		mean_best = loss,
		target_mean_best = target_loss_mean,
		hit_best = ctrl.hit;
	else
		tack = !tack;

	const auto grad
	{
		!tack? loss : -loss
	};

	const auto steps
	{
		(opts.training_steps + opts.validation_steps + opts.testing_steps) / opts.batch_size
	};

	const auto step
	{
		this->epoch.id * steps + this->id
	};

	log::logf
	{
		log, improve? log::level::INFO: log::level::ERROR,
		"epoch:%u step:%u completed range[%u -> %zu] dsid:%u target:%-10.7f select:%-10.7f loss:%-10.7f [ %10.7f ] hit:%u miss:%u",
		this->epoch.id,
		step,
		this->start,
		this->start + opts.batch_size,
		this->id * opts.batch_size + ctrl.clk.samp,
		target_loss_mean,
		select_loss_mean,
		loss,
		grad * opts.alpha,
		ctrl.hit,
		ctrl.miss,
	};

	if(!opts.alpha)
		return true;

	if(!improve)
		return false;

	cl::exec
	{
		desc.model->decode->master[0], std::memory_order_acq_rel
	};

	auto &model
	{
		*mutable_cast(desc.model->decode_const)
	};

	const mutable_buffer model_buffer
	{
		reinterpret_cast<char *>(&model),
		sizeof(gpt::model::decoder) * 3
	};

	const mutable_buffer checkpoint_buffer
	{
		reinterpret_cast<char *>(&model) + sizeof(gpt::model::decoder) * 3,
		sizeof(gpt::model::decoder) * 3
	};

	if(improve)
		copy(checkpoint_buffer, model_buffer);
	else
		copy(model_buffer, checkpoint_buffer);

	ircd::timer stopwatch;
	backprop(opts, step, grad, model, epoch.moment);
	allocator::sync(model_buffer);

	char pbuf[1][32];
	log::logf
	{
		log, improve? log::level::DEBUG: log::level::ERROR,
		"backpropagation step:%u lr:%-8.6f mean:%-10.7f$L hits:%-5u Tbest:%-10.7f$L Mbest:%-10.7f$L Hbest:%-5lu grad:{ %10.7f$L } %s",
		step,
		opts.alpha,
		loss_mean,
		ctrl.hit,
		target_mean_best,
		mean_best,
		hit_best,
		grad,
		pretty(pbuf[0], stopwatch.at<milliseconds>(), 1),
	};

	return true;
}

bool
ircd::gpt::step::operator()()
{
	gpt::samp samp
	{
		*this
	};

	while(!samp())
		ctx::interruption_point();

	return done();
}

bool
ircd::gpt::step::done()
const noexcept
{
	return ctrl.clk.step != id;
}

void
ircd::gpt::step::profile_accumulate(const pipe::prof &profile)
{
	for(size_t i(0); i < profile.ts.size(); ++i)
		for(size_t j(0); j < profile.phases; ++j)
			this->profile.ts[i][j] += profile.ts[i][j];

	epoch.profile_accumulate(profile);
}

///////////////////////////////////////////////////////////////////////////////
//
// samp::samp
//

ircd::gpt::samp::samp(gpt::step &step)
:step
{
	step
}
,desc
{
	step.desc
}
,opts
{
	step.opts
}
,ctrl
{
	step.ctrl
}
,id
{
	ctrl.clk.samp
}
,accept
{
	-1
}
,dispatch
{
	1
}
,cycle
{
	0
}
,tokens
{
	tokenize()
}
,count
{
	int(opts.limit) > 0?
		tokens - opts.limit:
	int(opts.limit) < 0?
		std::abs(int(opts.limit)):
		tokens
}
{
	desc.cached = 0;

	ctrl.clk.cycle = cycle;
	ctrl.dispatch = dispatch;
	ctrl.accept = accept;
	ctrl.count = count;
	ctrl.tokens = tokens;
	ctrl.magic = 0xDEADBEEF;

	for(uint i(0); i < opts.labels; ++i)
	{
		ctrl.label[i].ppl = {{0}};
		ctrl.label[i].loss = {{0}};
	}

	assert(ctrl.count > 0);
	assert(ctrl.count < opts.context_tokens);
	assert(ctrl.count <= ctrl.tokens);

	if(opts.debug & 0x01)
		for(uint j(0); j < ctrl.count; ++j)
			log_debug_token(opts, ctrl, j);
}

ircd::gpt::samp::~samp()
noexcept
{
	if(run::level != run::level::RUN)
		return;

	cl::exec
	{
		desc.ctrl, std::memory_order_acq_rel
	};

	if(opts.debug & 0x04)
		log_debug(opts, ctrl);

	if(opts.debug & 0x40)
		log_debug_labels(opts, ctrl);

	if(opts.debug & 0x20000000U)
		log_debug_prof(opts, ctrl, this->profile);
}

bool
ircd::gpt::samp::operator()()
{
	if(dispatch > 0)
	{
		ctx::interruption_point();
		queue.emplace_back(*this);
		desc.cached = tokens;
		tokens += count < tokens? 0: 1;
		++cycle;
		++count;
		--dispatch;
		return false;
	}

	while(!queue.empty())
	{
		const unwind pop{[this]
		{
			queue.pop_front();
		}};

		if(evaluate(queue.front()))
			break;
	}

	return done();
}

bool
ircd::gpt::samp::done()
const noexcept
{
	return accept >= 0;
}

uint
ircd::gpt::samp::tokenize()
{
	const auto idx
	{
		step.start + ctrl.clk.samp
	};

	const gpt::model::text text
	{
		gpt::model::default_data.at(idx)
	};

	const json::string input
	{
		json::get<"text"_>(text)
	};

	thread_local char str_buf[16_KiB];
	const string_view str
	{
		json::unescape(str_buf, input)
	};

	assert(!empty(str));
	static const auto delim
	{
		"\n\n"_sv
	};

	const int phrases
	(
		ircd::token_count(str, delim)
	);

	uint count(0);
	int p(phrases);
	assert(p >= 0);

	if(startswith(str, delim))
	{
		ctrl.token[count++] = 198;
		ctrl.token[count++] = 198;
	}

	ircd::tokens(str, delim, [this, &count, &p, &phrases]
	(const string_view &phrase) noexcept -> bool
	{
		assert(!empty(phrase));
		const vector_view<u16> buf
		{
			ctrl.token + count, opts.buffer_tokens - count
		};

		const auto in
		{
			gpt::vocab::tokenize(buf, phrase)
		};

		if(count + size(in) + 2 > opts.context_tokens)
			return false;

		count += size(in);
		ctrl.token[count++] = 198;
		ctrl.token[count++] = 198;

		assert(p > 0);
		marker[--p] = count;
		return true;
	});

	for(assert(p >= 0); p < phrases; ++p)
		if(marker[p] <= opts.context_tokens)
			break;

	assert(p <= phrases);
	count = marker[p];

	for(uint i(count); i < opts.buffer_tokens; ++i)
		ctrl.token[i] = 198;

	if(!endswith(str, delim))
		count -= 2;

	assert(count > 0);
	assert(count <= opts.context_tokens);
	return count;
}

bool
ircd::gpt::samp::evaluate(pipe::cycle &cycle)
{
	cl::exec
	{
		desc.frame[cycle.frame], std::memory_order_consume
	};

	const auto &frame
	{
		acquire(cycle)
	};

	if(!retire(cycle, frame))
		return false;

	memcpy(&ctrl, &frame, sizeof(gpt::ctrl));

	const uint
	batch_size = opts.batch_size,
	samps = opts.training_steps + opts.validation_steps + opts.testing_steps,
	steps = samps / batch_size;

	const bool
	accepting = accept >= 0,
	cycling = !accepting,
	sampling = accepting,
	stepping = sampling && (frame.clk.samp + 1) >= batch_size,
	epoching = stepping && (frame.clk.step + 1) >= steps;

	//ctrl[ctrl.count] = ctrl.select.logit.token;
	//ctrl.count++;

	if(accepting)
	{
		ctrl.clk.cycle += cycling;
		ctrl.clk.samp += sampling;
		ctrl.clk.step += stepping;
		ctrl.clk.epoch += epoching;
	}

	return true;
}

bool
ircd::gpt::samp::retire(pipe::cycle &cycle,
                        const gpt::ctrl &frame)
{
	assert(accept < 0);
	accept = frame.accept;
	dispatch = frame.dispatch;

	if(cl::profile_queue)
	{
		const pipe::prof profile
		{
			cycle
		};

		if(opts.debug & 0x10000000U)
			log_debug_prof(opts, frame, profile);

		profile_accumulate(profile);
	}

	if(opts.debug & 0x02)
		log_debug(opts, frame);

	if(opts.debug & 0x20)
		log_debug_labels(opts, frame);

	if(opts.debug & 0x10)
		log_debug_topn(opts, frame);

	if(opts.debug & 0x200)
		log_debug_attns_top(opts, frame);

	dispatch &= boolmask<uint>(ircd::run::level == run::level::RUN);
	dispatch &= boolmask<uint>(!ctx::interruption_requested());
	dispatch &= boolmask<uint>(accept < 0);
	const bool finished
	{
		dispatch == 0
	};

	return finished;
}

void
ircd::gpt::samp::profile_accumulate(const pipe::prof &profile)
{
	for(size_t i(0); i < profile.ts.size(); ++i)
		for(size_t j(0); j < profile.phases; ++j)
			this->profile.ts[i][j] += profile.ts[i][j];

	step.profile_accumulate(profile);
}

///////////////////////////////////////////////////////////////////////////////
//
// ctrl
//

ircd::string_view
ircd::gpt::debug_top(const mutable_buffer &out,
                     const opts &opts,
                     const ctrl &ctrl,
                     const uint i)
{
	thread_local char buf[2][256];

	assert(opts.top_n > i);
	const auto &top
	{
		ctrl.top[i]
	};

	return fmt::sprintf
	{
		out, "%s T%02d %s",
		vocab::debug(buf[0], top.token, 1),
		i,
		debug(buf[1], opts, top),
	};
}

ircd::string_view
ircd::gpt::debug_label(const mutable_buffer &out,
                       const opts &opts,
                       const ctrl &ctrl,
                       const uint i,
                       const uint fmt)
{
	thread_local char buf[2][256];

	assert(opts.labels > i);
	const auto &label
	{
		ctrl.label[i]
	};

	return fmt::sprintf
	{
		out, "%s L%02d %s",
		vocab::debug(buf[0], label.logit.token, 1),
		i,
		debug(buf[1], opts, label, fmt),
	};
}

ircd::string_view
ircd::gpt::debug_attn(const mutable_buffer &out,
                      const opts &opts,
                      const ctrl &ctrl,
                      const uint ti)
{
	thread_local char buf[4][256];
	assert(ti < ctrl.count);

	memset(buf[1], 0x0, sizeof(buf[1]));
	for(uint i(0); i < opts.layers; ++i)
	{
		const auto f{[&](const auto &a) { return a == ti; }};
		if(std::none_of(ctrl.attn[i], ctrl.attn[i] + opts.attn_rank, f))
			continue;

		strlcat{buf[1], fmt::sprintf
		{
			buf[2], "  %1x[", uint(i)
		}};

		for(uint j(0); j < opts.attn_rank; ++j)
			if(ctrl.attn[i][j] == ti)
				strlcat{buf[1], fmt::sprintf
				{
					buf[2], "%1x", uint(j)
				}};

		strlcat{buf[1], "]"_sv};
	}

	return fmt::sprintf
	{
		out, "%s [%3u] <-%s",
		vocab::debug(buf[0], ctrl.token[ti], 1),
		ti,
		string_view{buf[1]},
	};
}

ircd::string_view
ircd::gpt::debug(const mutable_buffer &out,
                 const opts &opts,
                 const ctrl &ctrl)
{
	thread_local char
	buf[8][128],
	tmbuf[4][32];

	int top_idx {-1};
	for(uint i(0); i < opts.top_n; ++i)
		if(ctrl.top[i].token == ctrl.select.logit.token)
		{
			top_idx = i;
			break;
		}

	return fmt::sprintf
	{
		out, "%s %s %c T%02d %4u %6.2f%% %10.7f$L %c %s %s",
		vocab::debug(buf[0], ctrl.select.logit.token, 1),
		debug(buf[1], opts, ctrl.select),
		ctrl.target.logit.token == ctrl.top[0].token? '=' : ' ',
		top_idx,
		ctrl.hit,
		(ctrl.hit / float(ctrl.hit + ctrl.miss)) * 100.0f,
		ctrl.target.loss.mean - ctrl.select.loss.mean,
		ctrl.target.logit.token == ctrl.select.logit.token? '=' : ' ',
		debug(buf[2], opts, ctrl.target),
		vocab::debug(buf[3], ctrl.target.logit.token, 1),
	};
}

ircd::string_view
ircd::gpt::debug(const mutable_buffer &out,
                 const opts &opts,
                 const ctrl_label &label,
                 const uint fmt)
{
	thread_local char buf[64], bar[128];

	const auto diff
	{
		log2f(65536) - label.loss.mean
	};

	const auto pct
	{
		(diff / log2f(opts.logits)) * 100.0f
	};

	const auto barsz
	{
		std::min(uint(pct), std::min(66U, uint(sizeof(bar) - 1)))
	};

	memset(bar, '|', barsz);
	bar[barsz] = '\0';

	return fmt::sprintf
	{
		out,
		fmt == 1?
			"%s %10.7f$La %6.2f%% %s":
			"%s %10.7f$La",
		debug(buf, opts, label.logit, fmt),
		label.loss.mean,
		pct,
		string_view{bar},
	};
}

ircd::string_view
ircd::gpt::debug(const mutable_buffer &out,
                 const opts &opts,
                 const ctrl_logit &logit,
                 const uint fmt)
{
	return fmt::sprintf
	{
		out, "%6.2f%% %10.7f$L %5.1f$P",
		logit.samax * 100.0f,
		+0.0f - logf(logit.samax),
		(1.0f - logit.samax) * log2f(opts.logits),
	};
}

ircd::string_view
ircd::gpt::debug_head(const mutable_buffer &out,
                      const opts &opts,
                      const ctrl &ctrl)
{
	thread_local char head[64];

	return fmt::sprintf
	{
		out, "%s[%4u]-%1u",
		debug_head(head, opts, ctrl.clk),
		ctrl.count,
		ctrl.dispatch,
	};
}

ircd::string_view
ircd::gpt::debug_head(const mutable_buffer &out,
                      const opts &opts,
                      const ctrl_clk &clk)
{
	return fmt::sprintf
	{
		out, "%02u:%06u|%04u|%04u|%04u",
		clk.epoch,
		clk.step * opts.batch_size + clk.samp,
		clk.step,
		clk.samp,
		clk.cycle,
	};
}

ircd::string_view
ircd::gpt::debug_token(const mutable_buffer &out,
                       const opts &opts,
                       const ctrl &ctrl,
                       const uint fmt)
{
	assert(ctrl.count > 0);
	const auto pos
	{
		ctrl.count - 1
	};

	return debug_token_at(out, opts, ctrl, pos, fmt);
}

ircd::string_view
ircd::gpt::debug_token_at(const mutable_buffer &out,
                          const opts &opts,
                          const ctrl &ctrl,
                          const uint i,
                          const uint fmt)
{
	const auto &token
	{
		ctrl.token[i]
	};

	return vocab::debug(out, token, fmt);
}

///////////////////////////////////////////////////////////////////////////////
//
// opts
//

ircd_gpt_opts::ircd_gpt_opts()
noexcept
:seed
{
	1234567890UL
}
,top_k
{
	16
}
,top_p
{
	0.90f
}
,top_n
{
	0
}
,labels
{
	0
}
,frames
{
	8
}
,limit
{
	-1U
}
,debug
{
	0x00
}
,accept
{
	{ 198, 198, ushort(-1),    },
	{ 0, 0, 0, ushort(-1),     },
	{ ushort(-1),              },
	{ ushort(-1),              },
}
,batch_size
{
	32
}
,training_steps
{
	250000
}
,validation_steps
{
	5000
}
,testing_steps
{
	5000
}
,alpha
{
	0.00002
}
,beta
{
	0.9f,
	0.999f,
}
,epsilon
{
	0.00001
}
,lambda
{
	0.5
}
,logits
{
	50256
}
,buffer_tokens
{
	1024 - 16 // XXX
}
,context_tokens
{
	512 // 1024
}
,layers
{
	12
}
,lanes
{
	4
}
,embed_elems
{
	768
}
,embed_width
{
	embed_elems / lanes
}
,attn_rank
{
	12
}
,attn_mult
{
	3
}
,attn_elems
{
	embed_elems * attn_mult
}
,attn_fcon_width
{
	attn_elems / lanes
}
,attn_fcon_height
{
	embed_elems / lanes
}
,attn_proj_width
{
	embed_elems / lanes
}
,attn_proj_height
{
	embed_elems / lanes
}
,attn_self_elems
{
	(uint(powl(context_tokens, 2)) / 2) * attn_rank
}
,ffnn_mult
{
	4
}
,ffnn_elems
{
	embed_elems * ffnn_mult
}
,ffnn_fcon_width
{
	ffnn_elems / lanes
}
,ffnn_fcon_height
{
	embed_elems / lanes
}
,ffnn_proj_width
{
	embed_elems / lanes
}
,ffnn_proj_height
{
	ffnn_elems / lanes
}
{
}