0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2025-01-13 08:23:56 +01:00

ircd::gpt: Model structural tweaks; task structure; various.

This commit is contained in:
Jason Volk 2021-03-29 18:14:55 -07:00
parent cb45dcc840
commit 29fb7910b7
4 changed files with 47 additions and 22 deletions

View file

@ -17,8 +17,10 @@ namespace ircd::gpt::model
struct attn;
struct ffnn;
struct block;
struct embed;
struct decoder;
constexpr auto align {64};
extern const decoder *default_model;
}
@ -26,29 +28,35 @@ namespace ircd::gpt::model
struct ircd::gpt::model::attn
{
float
attn_bias alignas(64) [2304],
attn_weight alignas(64) [768][2304],
proj_bias alignas(64) [768],
proj_weight alignas(64) [768][768];
bool bias alignas(64) [1024][1024];
attn_bias alignas(align) [2304],
attn_weight alignas(align) [768][2304];
bool
bias alignas(align) [1024][1024];
float
proj_bias alignas(align) [768],
proj_weight alignas(align) [768][768];
};
/// Feed-forward neural network
struct ircd::gpt::model::ffnn
{
float
fc_bias alignas(64) [3072],
fc_weight alignas(64) [768][3072],
proj_bias alignas(64) [768],
proj_weight alignas(64) [3072][768];
fc_bias alignas(align) [3072],
fc_weight alignas(align) [768][3072];
float
proj_bias alignas(align) [768],
proj_weight alignas(align) [3072][768];
};
/// Layer normalization
struct ircd::gpt::model::norm
{
float
bias alignas(64) [768],
weight alignas(64) [768];
bias alignas(align) [768],
weight alignas(align) [768];
};
/// Transformer block
@ -56,15 +64,24 @@ struct ircd::gpt::model::block
{
norm ln1;
model::attn attn;
norm ln2;
model::ffnn ffnn;
};
struct ircd::gpt::model::decoder
/// Vocabulary embeddings
struct ircd::gpt::model::embed
{
float
wpe alignas(64) [1024][768],
wte alignas(64) [65536][768];
block layer[12];
norm f;
pos alignas(align) [1024][768],
token alignas(align) [65536][768];
};
struct ircd::gpt::model::decoder
{
block layer[12];
norm f;
embed word;
}
__attribute__((packed));

View file

@ -23,6 +23,12 @@ struct ircd::gpt::task
/// Current task status.
enum status status {'\0'};
/// State counters for the accept codes specified in the options.
uint8_t accept_seq[3] {0};
/// State counters for the error codes specified in the options.
uint8_t error_seq[3] {0};
/// Accumulates the number of executions by the user. Each call to the
/// interface is an execution.
uint64_t epoch {0};

View file

@ -183,12 +183,12 @@ ircd::gpt::embed(float *const out,
assert(opts.model);
const auto &wpe
{
opts.model->wpe[position]
opts.model->word.pos[position]
};
const auto &wte
{
opts.model->wte[token]
opts.model->word.token[token]
};
for(uint j(0); j < 768; ++j)

View file

@ -141,6 +141,8 @@ ircd::gpt::model::init_from_cache(const string_view &cache_path)
};
fs::map::opts map_opts;
map_opts.huge2mb = true;
map_opts.locked = false;
default_model_shm = fs::map
{
fd, map_opts, sizeof(decoder)
@ -283,9 +285,9 @@ ircd::gpt::model::init_wpe_weight(decoder &d,
{
size_t j(0);
for(const auto &elem : vec)
d.wpe[i][j++] = lex_cast<float>(elem);
d.word.pos[i][j++] = lex_cast<float>(elem);
always_assert(j == sizeof(d.wpe[i]) / sizeof(float));
always_assert(j == sizeof(d.word.pos[i]) / sizeof(float));
++i;
}
}
@ -301,9 +303,9 @@ ircd::gpt::model::init_wte_weight(decoder &d,
{
size_t j(0);
for(const auto &elem : vec)
d.wte[i][j++] = lex_cast<float>(elem);
d.word.token[i][j++] = lex_cast<float>(elem);
always_assert(j == sizeof(d.wte[i]) / sizeof(float));
always_assert(j == sizeof(d.word.token[i]) / sizeof(float));
++i;
}
}