0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2025-01-14 00:34:18 +01:00

ircd::gpt::pipe: Reuse logsm buffer for logexp intermediate values.

This commit is contained in:
Jason Volk 2022-01-25 09:34:34 -08:00
parent 4ff9176086
commit b7b1328352
3 changed files with 1 additions and 15 deletions

View file

@ -24,7 +24,6 @@ struct ircd::gpt::pipe::desc
master, // [root] single allocation for additional buffers:
accum, // [-sub] accumulator (tokens * embed * float)
logit, // [-sub] result logit vector (50257 * float)
logexp, // [-sub] outputs distribution (50257 * float)
logsm, // [-sub] outputs distribution (50257 * float)
ctrl, // [root] control page
opts; // [root] options page

View file

@ -620,7 +620,6 @@ ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit,
const uint i)
{
@ -640,7 +639,6 @@ ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit,
const uint i)
{
@ -732,7 +730,6 @@ __attribute__((flatten))
ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__global const float *const restrict logsm,
__global const float *const restrict logexp,
__global const float *const restrict logit)
{
const uint

View file

@ -506,20 +506,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
accum.offset() + off_t(accum.size()),
},
}
,logexp
{
master,
{
65536 * sizeof(float),
logit.offset() + off_t(logit.size()),
},
}
,logsm
{
master,
{
65536 * sizeof(float),
logexp.offset() + off_t(logexp.size()),
logit.offset() + off_t(logit.size()),
},
}
,ctrl
@ -569,7 +561,6 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ctrl,
opts,
logsm,
logexp,
logit,
}
,lm_select
@ -579,7 +570,6 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
ctrl,
opts,
logsm,
logexp,
logit,
}
,lm_norm_backprop