0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-12-27 07:54:05 +01:00

ircd::gpt::pipe: Reuse logsm buffer for logexp intermediate values.

This commit is contained in:
Jason Volk 2022-01-25 09:34:34 -08:00
parent 4ff9176086
commit b7b1328352
3 changed files with 1 additions and 15 deletions

View file

@ -24,7 +24,6 @@ struct ircd::gpt::pipe::desc
master, // [root] single allocation for additional buffers: master, // [root] single allocation for additional buffers:
accum, // [-sub] accumulator (tokens * embed * float) accum, // [-sub] accumulator (tokens * embed * float)
logit, // [-sub] result logit vector (50257 * float) logit, // [-sub] result logit vector (50257 * float)
logexp, // [-sub] outputs distribution (50257 * float)
logsm, // [-sub] outputs distribution (50257 * float) logsm, // [-sub] outputs distribution (50257 * float)
ctrl, // [root] control page ctrl, // [root] control page
opts; // [root] options page opts; // [root] options page

View file

@ -620,7 +620,6 @@ ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts, __constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx, __local const ushort *const restrict idx,
__global const float *const restrict logsm, __global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit, __global const float *const restrict logit,
const uint i) const uint i)
{ {
@ -640,7 +639,6 @@ ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts, __constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx, __local const ushort *const restrict idx,
__global const float *const restrict logsm, __global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit, __global const float *const restrict logit,
const uint i) const uint i)
{ {
@ -732,7 +730,6 @@ __attribute__((flatten))
ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl, ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts, __constant const struct ircd_gpt_opts *const opts,
__global const float *const restrict logsm, __global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit) __global const float *const restrict logit)
{ {
const uint const uint

View file

@ -506,20 +506,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
accum.offset() + off_t(accum.size()), accum.offset() + off_t(accum.size()),
}, },
} }
,logexp
{
master,
{
65536 * sizeof(float),
logit.offset() + off_t(logit.size()),
},
}
,logsm ,logsm
{ {
master, master,
{ {
65536 * sizeof(float), 65536 * sizeof(float),
logexp.offset() + off_t(logexp.size()), logit.offset() + off_t(logit.size()),
}, },
} }
,ctrl ,ctrl
@ -569,7 +561,6 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ctrl, ctrl,
opts, opts,
logsm, logsm,
logexp,
logit, logit,
} }
,lm_select ,lm_select
@ -579,7 +570,6 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ctrl, ctrl,
opts, opts,
logsm, logsm,
logexp,
logit, logit,
} }
,lm_norm_backprop ,lm_norm_backprop