
ircd::gpt::task: Refactor generator interface to member functions.

Jason Volk 2022-07-01 18:50:20 -07:00
parent 56d944f33e
commit 6d2da3b4f1
6 changed files with 102 additions and 148 deletions
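
In caller terms, this commit removes the ircd::gpt::generate() free functions (the header deleted below) in favor of operator() overloads on gpt::task. A minimal before/after sketch, not taken from the tree, assuming the ircd headers are included and that `buf`, `opts` and `ctrl` are a mutable_buffer, a gpt::opts and a gpt::ctrl set up by the caller:

// Before this commit: free function taking the task by reference.
//   const auto output{ ircd::gpt::generate(buf, "input text", task) };
//
// After this commit: the task object itself is callable.
ircd::gpt::task task
{
	&opts, &ctrl       // gpt::opts *, gpt::ctrl * (both default to nullptr)
};
const ircd::string_view output
{
	task(buf, "input text")
};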

View file

@@ -1,28 +0,0 @@
// Matrix Construct
//
// Copyright (C) Matrix Construct Developers, Authors & Contributors
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice is present in all copies. The
// full license for this software is available in the LICENSE file.
#pragma once
#define HAVE_IRCD_GPT_GENERATE_H
namespace ircd::gpt
{
void
generate(task &);
vector_view<u16>
generate(const vector_view<u16> &out,
const vector_view<const u16> &in,
task &);
string_view
generate(const mutable_buffer &out,
const string_view &in,
task &);
}

View file

@@ -36,7 +36,6 @@ namespace ircd::gpt
#include "step.h"
#include "epoch.h"
#include "task.h"
#include "generate.h"
namespace ircd::gpt
{

View file

@@ -56,7 +56,7 @@ struct ircd_gpt_opts
/// Limit number of output tokens. Default of -1; other halting conditions
/// will be used.
uint limit;
int limit;
/// Bitbar toggling various debug modes.
uint debug;

View file

@@ -44,7 +44,17 @@ struct ircd::gpt::task
public:
bool done() const noexcept;
bool operator()();
bool
operator()();
vector_view<u16>
operator()(const vector_view<u16> &out,
const vector_view<const u16> &in);
string_view
operator()(const mutable_buffer &out,
const string_view &in);
task(const gpt::opts * = nullptr,
gpt::ctrl * = nullptr);

View file

@@ -14,112 +14,6 @@ ircd::gpt::log
"gpt"
};
ircd::string_view
ircd::gpt::generate(const mutable_buffer &out,
const string_view &in,
task &task)
{
u16 input_buf[1024];
const auto input_tokens
{
gpt::vocab::tokenize(input_buf, in)
};
u16 output_buf[1024];
const auto output_tokens
{
generate(output_buf, input_tokens, task)
};
const auto output
{
gpt::vocab::detokenize(out, output_tokens)
};
return output;
}
ircd::vector_view<ircd::u16>
ircd::gpt::generate(const vector_view<u16> &out,
const vector_view<const u16> &in,
task &task)
{
assert(task.opts);
const auto &opts
{
*task.opts
};
assert(task.ctrl);
auto &ctrl
{
*task.ctrl
};
size_t in_i(0);
while(in_i < in.size() && ctrl.count < opts.buffer_tokens)
if(in[in_i] == 628)
{
ctrl.token[ctrl.count++] = 198;
ctrl.token[ctrl.count++] = 198;
in_i++;
}
else ctrl.token[ctrl.count++] = in[in_i++];
generate(task);
size_t out_i(0);
for(; out_i < out.size() && in_i + out_i < ctrl.count; out_i++)
out[out_i] = ctrl.token[in_i + out_i];
return vector_view<u16>
{
out, out_i
};
}
void
ircd::gpt::generate(task &task)
{
assert(task.opts);
const auto &opts
{
*task.opts
};
assert(task.ctrl);
auto &ctrl
{
*task.ctrl
};
gpt::epoch epoch
{
task
};
gpt::step step
{
epoch
};
gpt::samp samp
{
step
};
bool halt {false}; do
{
gpt::pipe::cycle cycle
{
samp
};
halt = !samp.evaluate(cycle);
}
while(!halt);
}
//
// debug
//
@@ -394,6 +288,80 @@ noexcept
{
}
ircd::string_view
ircd::gpt::task::operator()(const mutable_buffer &out,
const string_view &in)
{
u16 input_buf[1024];
const auto input_tokens
{
gpt::vocab::tokenize(input_buf, in)
};
u16 output_buf[1024];
const auto output_tokens
{
operator()(output_buf, input_tokens)
};
const auto output
{
gpt::vocab::detokenize(out, output_tokens)
};
return output;
}
ircd::vector_view<ircd::u16>
ircd::gpt::task::operator()(const vector_view<u16> &out,
const vector_view<const u16> &in)
{
assert(this->opts);
const auto &opts{*this->opts};
assert(this->ctrl);
auto &ctrl{*this->ctrl};
size_t in_i(0);
for(; in_i < in.size() && ctrl.count < opts.buffer_tokens; in_i++)
if(in[in_i] == 628)
{
ctrl.token[ctrl.count++] = 198;
ctrl.token[ctrl.count++] = 198;
}
else ctrl.token[ctrl.count++] = in[in_i];
gpt::epoch epoch
{
*this,
};
gpt::step step
{
epoch
};
gpt::samp samp
{
step
};
bool halt {false}; do
{
halt = samp();
}
while(!halt);
size_t out_i(0);
for(; out_i < out.size() && in_i + out_i < ctrl.count; out_i++)
out[out_i] = ctrl.token[in_i + out_i];
return vector_view<u16>
{
out, out_i
};
}
bool
ircd::gpt::task::operator()()
{
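
For reference, the input-copy loop in the hunk above expands any input token 628 into two 198 tokens while appending into ctrl.token (in GPT-2's byte-pair vocabulary these appear to be the merged "\n\n" token and the single "\n" token, though that mapping is not stated in the diff). A self-contained sketch of just that behavior, with hypothetical names:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the loop above: append `in` to `out`,
// splitting the merged token 628 into two 198 tokens, and stop once
// `max` tokens are held. Returns the number of input tokens consumed.
static std::size_t
append_tokens(std::vector<std::uint16_t> &out,
              const std::vector<std::uint16_t> &in,
              const std::size_t max)
{
	std::size_t i(0);
	for(; i < in.size() && out.size() < max; ++i)
		if(in[i] == 628)
		{
			out.push_back(198);
			out.push_back(198);
		}
		else out.push_back(in[i]);

	return i;
}
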
@@ -776,15 +744,18 @@ ircd::gpt::samp::samp(gpt::step &step)
}
,tokens
{
tokenize()
ctrl.count?:
tokenize()
}
,count
{
int(opts.limit) > 0?
opts.limit > 0?
tokens - opts.limit:
int(opts.limit) < 0?
std::abs(int(opts.limit)):
tokens
opts.limit < 0?
std::abs(opts.limit):
!ctrl.count?
tokens:
1
}
{
desc.cached = 0;
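
One way to read the two initializers above (a restatement, not code from the commit): `tokens` reuses a token count already present in the control block and only tokenizes when that count is zero; `count` then picks the starting position from the now-signed limit:

#include <cstdlib>

// Hypothetical restatement of the `count` initializer above.
static int
starting_count(const int limit,     // opts.limit, signed as of this commit
               const int tokens,    // value chosen for `tokens` above
               const int prior)     // ctrl.count before construction
{
	if(limit > 0)
		return tokens - limit;      // positive limit

	if(limit < 0)
		return std::abs(limit);     // negative limit: its absolute value

	return !prior? tokens: 1;       // zero limit: all tokens when starting
	                                // fresh, 1 when resuming a prior count
}
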
@@ -840,7 +811,7 @@ ircd::gpt::samp::operator()()
ctx::interruption_point();
queue.emplace_back(*this);
desc.cached = tokens;
tokens += count < tokens? 0: 1;
tokens += count >= tokens;
++cycle;
++count;
--dispatch;
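
The change above folds the ternary into a bool-to-int increment; a quick standalone check (not from the tree) shows the two forms agree:

// The new form adds 1 exactly when the old ternary did.
constexpr unsigned
step_tokens(unsigned tokens, const unsigned count)
{
	// tokens += count < tokens? 0: 1;   // before
	tokens += count >= tokens;           // after: the bool promotes to 0 or 1
	return tokens;
}
static_assert(step_tokens(4, 4) == 5 && step_tokens(4, 3) == 4);
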
@@ -987,7 +958,7 @@ ircd::gpt::samp::evaluate(pipe::cycle &cycle)
stepping = sampling && (frame.clk.samp + 1) >= batch_size,
epoching = stepping && (frame.clk.step + 1) >= steps;
//ctrl[ctrl.count] = ctrl.select.logit.token;
//ctrl.token[ctrl.count] = ctrl.select.logit.token;
//ctrl.count++;
if(accepting)
@@ -1329,7 +1300,7 @@ noexcept
}
,limit
{
-1U
-1
}
,debug
{

View file

@@ -1081,15 +1081,17 @@ ircd_gpt_accept(__local struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts)
{
const bool
unlimited = opts->limit == -1U;
unlimited = opts->limit < 0;
const uint
batch_size = opts->batch_size,
samps = opts->training_steps + opts->validation_steps + opts->testing_steps,
steps = samps / batch_size;
steps = samps / batch_size,
limit_ = opts->limit,
unproc = ctrl->tokens - ctrl->count;
const int
limit = min(opts->limit, opts->context_tokens),
limit = min(limit_?: unproc, opts->context_tokens),
cycle_remain = limit - (ctrl->clk.cycle + 1), // cycle not yet incr
token_remain = opts->context_tokens - ctrl->count, // but count already incr
remain_ = min(cycle_remain, token_remain),
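
One reading of the new limit computation above, restated as plain C++ (a sketch; the kernel uses OpenCL's min() and the GNU `a ?: b` shorthand for `a ? a : b`): a negative opts->limit wraps to a very large unsigned value so the context size wins, a zero limit falls back to the number of unprocessed tokens, and a positive limit is clamped to the context size.

// Hypothetical restatement of `limit = min(limit_?: unproc, opts->context_tokens)`.
static unsigned
effective_limit(const int opt_limit,           // opts->limit; -1 means unlimited
                const unsigned unproc,         // ctrl->tokens - ctrl->count
                const unsigned context_tokens) // opts->context_tokens
{
	const unsigned limit_(opt_limit);          // negative wraps to a huge value
	const unsigned want
	{
		limit_? limit_: unproc                 // the GNU `limit_?: unproc` expanded
	};
	return want < context_tokens? want: context_tokens;
}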