mirror of
https://github.com/matrix-construct/construct
synced 2024-12-26 07:23:53 +01:00
ircd::gpt::pipe: Reuse logsm buffer for logexp intermediate values.
This commit is contained in:
parent
4ff9176086
commit
b7b1328352
3 changed files with 1 addition and 15 deletions
|
@ -24,7 +24,6 @@ struct ircd::gpt::pipe::desc
|
|||
master, // [root] single allocation for additional buffers:
|
||||
accum, // [-sub] accumulator (tokens * embed * float)
|
||||
logit, // [-sub] result logit vector (50257 * float)
|
||||
logexp, // [-sub] outputs distribution (50257 * float)
|
||||
logsm, // [-sub] outputs distribution (50257 * float)
|
||||
ctrl, // [root] control page
|
||||
opts; // [root] options page
|
||||
|
|
|
@ -620,7 +620,6 @@ ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
|
|||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local const ushort *const restrict idx,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit,
|
||||
const uint i)
|
||||
{
|
||||
|
@ -640,7 +639,6 @@ ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
|
|||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local const ushort *const restrict idx,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit,
|
||||
const uint i)
|
||||
{
|
||||
|
@ -732,7 +730,6 @@ __attribute__((flatten))
|
|||
ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__global const float *const restrict logsm,
|
||||
__global const float *const restrict logexp,
|
||||
__global const float *const restrict logit)
|
||||
{
|
||||
const uint
|
||||
|
|
|
@ -506,20 +506,12 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
accum.offset() + off_t(accum.size()),
|
||||
},
|
||||
}
|
||||
,logexp
|
||||
{
|
||||
master,
|
||||
{
|
||||
65536 * sizeof(float),
|
||||
logit.offset() + off_t(logit.size()),
|
||||
},
|
||||
}
|
||||
,logsm
|
||||
{
|
||||
master,
|
||||
{
|
||||
65536 * sizeof(float),
|
||||
logexp.offset() + off_t(logexp.size()),
|
||||
logit.offset() + off_t(logit.size()),
|
||||
},
|
||||
}
|
||||
,ctrl
|
||||
|
@ -569,7 +561,6 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
ctrl,
|
||||
opts,
|
||||
logsm,
|
||||
logexp,
|
||||
logit,
|
||||
}
|
||||
,lm_select
|
||||
|
@ -579,7 +570,6 @@ ircd::gpt::pipe::desc::desc(pipe::code &code,
|
|||
ctrl,
|
||||
opts,
|
||||
logsm,
|
||||
logexp,
|
||||
logit,
|
||||
}
|
||||
,lm_norm_backprop
|
||||
|
|
Loading…
Reference in a new issue