mirror of
https://github.com/matrix-construct/construct
synced 2024-09-26 18:38:52 +02:00
ircd::gpt::task: Refactor generator interface to member functions.
This commit is contained in:
parent
56d944f33e
commit
6d2da3b4f1
6 changed files with 102 additions and 148 deletions
|
@ -1,28 +0,0 @@
|
|||
// Matrix Construct
|
||||
//
|
||||
// Copyright (C) Matrix Construct Developers, Authors & Contributors
|
||||
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
|
||||
//
|
||||
// Permission to use, copy, modify, and/or distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice is present in all copies. The
|
||||
// full license for this software is available in the LICENSE file.
|
||||
|
||||
#pragma once
|
||||
#define HAVE_IRCD_GPT_GENERATE_H
|
||||
|
||||
namespace ircd::gpt
|
||||
{
|
||||
void
|
||||
generate(task &);
|
||||
|
||||
vector_view<u16>
|
||||
generate(const vector_view<u16> &out,
|
||||
const vector_view<const u16> &in,
|
||||
task &);
|
||||
|
||||
string_view
|
||||
generate(const mutable_buffer &out,
|
||||
const string_view &in,
|
||||
task &);
|
||||
}
|
|
@ -36,7 +36,6 @@ namespace ircd::gpt
|
|||
#include "step.h"
|
||||
#include "epoch.h"
|
||||
#include "task.h"
|
||||
#include "generate.h"
|
||||
|
||||
namespace ircd::gpt
|
||||
{
|
||||
|
|
|
@ -56,7 +56,7 @@ struct ircd_gpt_opts
|
|||
|
||||
/// Limit number of output tokens. Default of -1; other halting conditions
|
||||
/// will be used.
|
||||
uint limit;
|
||||
int limit;
|
||||
|
||||
/// Bitbar toggling various debug modes.
|
||||
uint debug;
|
||||
|
|
|
@ -44,7 +44,17 @@ struct ircd::gpt::task
|
|||
|
||||
public:
|
||||
bool done() const noexcept;
|
||||
bool operator()();
|
||||
|
||||
bool
|
||||
operator()();
|
||||
|
||||
vector_view<u16>
|
||||
operator()(const vector_view<u16> &out,
|
||||
const vector_view<const u16> &in);
|
||||
|
||||
string_view
|
||||
operator()(const mutable_buffer &out,
|
||||
const string_view &in);
|
||||
|
||||
task(const gpt::opts * = nullptr,
|
||||
gpt::ctrl * = nullptr);
|
||||
|
|
199
ircd/gpt.cc
199
ircd/gpt.cc
|
@ -14,112 +14,6 @@ ircd::gpt::log
|
|||
"gpt"
|
||||
};
|
||||
|
||||
/// Text-to-text front-end for the generator.
///
/// The input string is tokenized into a fixed-size stack buffer, handed to
/// the token overload of generate(), and the tokens that come back are
/// detokenized into `out`.
///
/// - out: buffer receiving the detokenized text.
/// - in: text to tokenize and feed to the generator.
/// - task: the task forwarded to the token overload.
///
/// Returns the result of gpt::vocab::detokenize() on the produced tokens.
ircd::string_view
ircd::gpt::generate(const mutable_buffer &out,
                    const string_view &in,
                    task &task)
{
	// Scratch space for the tokenized prompt.
	u16 prompt_buf[1024];
	const auto prompt(gpt::vocab::tokenize(prompt_buf, in));

	// Scratch space for the tokens handed back by the token overload.
	u16 result_buf[1024];
	const auto result(generate(result_buf, prompt, task));

	return gpt::vocab::detokenize(out, result);
}
|
||||
|
||||
/// Token-level front-end for the generator.
///
/// Appends the tokens of `in` to the task's control block, runs
/// generate(task), then copies the tokens that follow the loaded input into
/// `out`.
///
/// - out: destination for produced tokens; at most out.size() are written.
/// - in: input tokens; loading stops once opts.buffer_tokens is reached.
/// - task: must have both `opts` and `ctrl` set (asserted).
///
/// Returns a view over the leading portion of `out` that was written.
ircd::vector_view<ircd::u16>
ircd::gpt::generate(const vector_view<u16> &out,
                    const vector_view<const u16> &in,
                    task &task)
{
	assert(task.opts);
	const auto &opts{*task.opts};

	assert(task.ctrl);
	auto &ctrl{*task.ctrl};

	for(size_t in_i(0); in_i < in.size() && ctrl.count < opts.buffer_tokens; ++in_i)
	{
		if(in[in_i] != 628)
		{
			ctrl.token[ctrl.count++] = in[in_i];
			continue;
		}

		// Token 628 is stored as a pair of 198's (presumably "\n\n" split into
		// two "\n" -- TODO confirm against the vocabulary). The pair needs two
		// free slots while the loop condition only guarantees one; stop
		// loading rather than push ctrl.count past opts.buffer_tokens.
		if(ctrl.count + 1 >= opts.buffer_tokens)
			break;

		ctrl.token[ctrl.count++] = 198;
		ctrl.token[ctrl.count++] = 198;
	}

	// Everything at or beyond this position is appended after the input. The
	// number of consumed input tokens cannot be used as the offset because an
	// expanded 628 occupies two slots and ctrl.count may not start at zero.
	const size_t out_begin(ctrl.count);

	generate(task);

	size_t out_i(0);
	for(; out_i < out.size() && out_begin + out_i < ctrl.count; ++out_i)
		out[out_i] = ctrl.token[out_begin + out_i];

	return vector_view<u16>
	{
		out, out_i
	};
}
|
||||
|
||||
void
|
||||
ircd::gpt::generate(task &task)
|
||||
{
|
||||
assert(task.opts);
|
||||
const auto &opts
|
||||
{
|
||||
*task.opts
|
||||
};
|
||||
|
||||
assert(task.ctrl);
|
||||
auto &ctrl
|
||||
{
|
||||
*task.ctrl
|
||||
};
|
||||
|
||||
gpt::epoch epoch
|
||||
{
|
||||
task
|
||||
};
|
||||
|
||||
gpt::step step
|
||||
{
|
||||
epoch
|
||||
};
|
||||
|
||||
gpt::samp samp
|
||||
{
|
||||
step
|
||||
};
|
||||
|
||||
bool halt {false}; do
|
||||
{
|
||||
gpt::pipe::cycle cycle
|
||||
{
|
||||
samp
|
||||
};
|
||||
|
||||
halt = !samp.evaluate(cycle);
|
||||
}
|
||||
while(!halt);
|
||||
}
|
||||
|
||||
//
|
||||
// debug
|
||||
//
|
||||
|
@ -394,6 +288,80 @@ noexcept
|
|||
{
|
||||
}
|
||||
|
||||
/// Text-to-text call operator.
///
/// The input string is tokenized into a fixed-size stack buffer, passed to
/// the token overload of operator(), and the tokens that come back are
/// detokenized into `out`.
///
/// - out: buffer receiving the detokenized text.
/// - in: text to tokenize and feed to the task.
///
/// Returns the result of gpt::vocab::detokenize() on the produced tokens.
ircd::string_view
ircd::gpt::task::operator()(const mutable_buffer &out,
                            const string_view &in)
{
	// Scratch space for the tokenized prompt.
	u16 prompt_buf[1024];
	const auto prompt(gpt::vocab::tokenize(prompt_buf, in));

	// Scratch space for the tokens handed back by the token overload.
	u16 result_buf[1024];
	const auto result((*this)(result_buf, prompt));

	return gpt::vocab::detokenize(out, result);
}
|
||||
|
||||
/// Token-level call operator.
///
/// Appends the tokens of `in` to this task's control block, builds the
/// epoch -> step -> samp chain and invokes the sampler until it reports a
/// halt, then copies the tokens that follow the loaded input into `out`.
///
/// - out: destination for produced tokens; at most out.size() are written.
/// - in: input tokens; loading stops once opts.buffer_tokens is reached.
///
/// Requires both `opts` and `ctrl` to be set on the task (asserted).
/// Returns a view over the leading portion of `out` that was written.
ircd::vector_view<ircd::u16>
ircd::gpt::task::operator()(const vector_view<u16> &out,
                            const vector_view<const u16> &in)
{
	assert(this->opts);
	const auto &opts{*this->opts};

	assert(this->ctrl);
	auto &ctrl{*this->ctrl};

	for(size_t in_i(0); in_i < in.size() && ctrl.count < opts.buffer_tokens; ++in_i)
	{
		if(in[in_i] != 628)
		{
			ctrl.token[ctrl.count++] = in[in_i];
			continue;
		}

		// Token 628 is stored as a pair of 198's (presumably "\n\n" split into
		// two "\n" -- TODO confirm against the vocabulary). The pair needs two
		// free slots while the loop condition only guarantees one; stop
		// loading rather than push ctrl.count past opts.buffer_tokens.
		if(ctrl.count + 1 >= opts.buffer_tokens)
			break;

		ctrl.token[ctrl.count++] = 198;
		ctrl.token[ctrl.count++] = 198;
	}

	// Everything at or beyond this position is appended after the input. The
	// number of consumed input tokens cannot be used as the offset because an
	// expanded 628 occupies two slots and ctrl.count may not start at zero.
	const size_t out_begin(ctrl.count);

	gpt::epoch epoch
	{
		*this
	};

	gpt::step step
	{
		epoch
	};

	gpt::samp samp
	{
		step
	};

	// Keep invoking the sampler until its call operator reports a halt.
	bool halt {false}; do
	{
		halt = samp();
	}
	while(!halt);

	size_t out_i(0);
	for(; out_i < out.size() && out_begin + out_i < ctrl.count; ++out_i)
		out[out_i] = ctrl.token[out_begin + out_i];

	return vector_view<u16>
	{
		out, out_i
	};
}
|
||||
|
||||
bool
|
||||
ircd::gpt::task::operator()()
|
||||
{
|
||||
|
@ -776,15 +744,18 @@ ircd::gpt::samp::samp(gpt::step &step)
|
|||
}
|
||||
,tokens
|
||||
{
|
||||
tokenize()
|
||||
ctrl.count?:
|
||||
tokenize()
|
||||
}
|
||||
,count
|
||||
{
|
||||
int(opts.limit) > 0?
|
||||
opts.limit > 0?
|
||||
tokens - opts.limit:
|
||||
int(opts.limit) < 0?
|
||||
std::abs(int(opts.limit)):
|
||||
tokens
|
||||
opts.limit < 0?
|
||||
std::abs(opts.limit):
|
||||
!ctrl.count?
|
||||
tokens:
|
||||
1
|
||||
}
|
||||
{
|
||||
desc.cached = 0;
|
||||
|
@ -840,7 +811,7 @@ ircd::gpt::samp::operator()()
|
|||
ctx::interruption_point();
|
||||
queue.emplace_back(*this);
|
||||
desc.cached = tokens;
|
||||
tokens += count < tokens? 0: 1;
|
||||
tokens += count >= tokens;
|
||||
++cycle;
|
||||
++count;
|
||||
--dispatch;
|
||||
|
@ -987,7 +958,7 @@ ircd::gpt::samp::evaluate(pipe::cycle &cycle)
|
|||
stepping = sampling && (frame.clk.samp + 1) >= batch_size,
|
||||
epoching = stepping && (frame.clk.step + 1) >= steps;
|
||||
|
||||
//ctrl[ctrl.count] = ctrl.select.logit.token;
|
||||
//ctrl.token[ctrl.count] = ctrl.select.logit.token;
|
||||
//ctrl.count++;
|
||||
|
||||
if(accepting)
|
||||
|
@ -1329,7 +1300,7 @@ noexcept
|
|||
}
|
||||
,limit
|
||||
{
|
||||
-1U
|
||||
-1
|
||||
}
|
||||
,debug
|
||||
{
|
||||
|
|
|
@ -1081,15 +1081,17 @@ ircd_gpt_accept(__local struct ircd_gpt_ctrl *const ctrl,
|
|||
__constant const struct ircd_gpt_opts *const opts)
|
||||
{
|
||||
const bool
|
||||
unlimited = opts->limit == -1U;
|
||||
unlimited = opts->limit < 0;
|
||||
|
||||
const uint
|
||||
batch_size = opts->batch_size,
|
||||
samps = opts->training_steps + opts->validation_steps + opts->testing_steps,
|
||||
steps = samps / batch_size;
|
||||
steps = samps / batch_size,
|
||||
limit_ = opts->limit,
|
||||
unproc = ctrl->tokens - ctrl->count;
|
||||
|
||||
const int
|
||||
limit = min(opts->limit, opts->context_tokens),
|
||||
limit = min(limit_?: unproc, opts->context_tokens),
|
||||
cycle_remain = limit - (ctrl->clk.cycle + 1), // cycle not yet incr
|
||||
token_remain = opts->context_tokens - ctrl->count, // but count already incr
|
||||
remain_ = min(cycle_remain, token_remain),
|
||||
|
|
Loading…
Reference in a new issue