mirror of
https://github.com/matrix-construct/construct
synced 2025-03-14 05:20:17 +01:00
ircd::gpt: Optimizations for matrix multiply.
This commit is contained in:
parent
1be7a8dea2
commit
8f90e7c0cd
1 changed files with 50 additions and 29 deletions
|
@ -51,21 +51,52 @@ ircd_gpt_ffnn_gelu(__local float4 *const out,
|
|||
}
|
||||
|
||||
// Matrix * Vector Multiply/Accumulate
|
||||
inline void
|
||||
inline float4
|
||||
__attribute__((flatten, always_inline))
|
||||
ircd_gpt_sgemv(__local float4 *const restrict out,
|
||||
__local const float4 *const restrict in,
|
||||
__global const float4 *const restrict bias,
|
||||
__global const float4 *const restrict weight,
|
||||
const uint width,
|
||||
const uint height,
|
||||
const uint segs)
|
||||
ircd_gpt_tmul_dot(__local const float4 *const restrict in,
|
||||
__global const float4 *const restrict bias,
|
||||
__global const float4 *const restrict weight,
|
||||
const uint width,
|
||||
const uint height,
|
||||
const uint col,
|
||||
const uint i,
|
||||
const uint j)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0),
|
||||
lanes = 4;
|
||||
|
||||
float4
|
||||
acc = 0.0f;
|
||||
|
||||
for(uint k = 0; k < lanes; ++k)
|
||||
{
|
||||
const uint
|
||||
row = j * lanes + k,
|
||||
cell = row * width + col;
|
||||
|
||||
acc += in[j][k] * weight[cell];
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
|
||||
// Matrix * Vector Multiply/Accumulate
|
||||
inline void
|
||||
__attribute__((flatten, always_inline))
|
||||
ircd_gpt_tmul(__local float4 *const restrict out,
|
||||
__local const float4 *const restrict in,
|
||||
__global const float4 *const restrict bias,
|
||||
__global const float4 *const restrict weight,
|
||||
const uint width,
|
||||
const uint height,
|
||||
const uint segs)
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
ln = get_local_size(0);
|
||||
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < segs; ++i)
|
||||
{
|
||||
|
@ -75,24 +106,14 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
|
|||
out[col] = bias[col];
|
||||
}
|
||||
|
||||
for(uint j = 0; j < height; ++j)
|
||||
for(uint i = 0; i < segs; ++i)
|
||||
{
|
||||
const uint
|
||||
col = i * ln + li;
|
||||
for(uint i = 0; i < segs; ++i)
|
||||
{
|
||||
const uint
|
||||
col = i * ln + li;
|
||||
|
||||
float4 acc = 0.0f;
|
||||
for(uint k = 0; k < lanes; ++k)
|
||||
{
|
||||
const uint
|
||||
row = j * lanes + k,
|
||||
cell = row * width + col;
|
||||
|
||||
acc += in[j][k] * weight[cell];
|
||||
}
|
||||
|
||||
out[col] += acc;
|
||||
}
|
||||
for(uint j = 0; j < height; ++j)
|
||||
out[col] += ircd_gpt_tmul_dot(in, bias, weight, width, height, col, i, j);
|
||||
}
|
||||
}
|
||||
|
||||
inline void
|
||||
|
@ -111,7 +132,7 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
height = opts->ffnn_height,
|
||||
tiles = opts->ffnn_mult;
|
||||
|
||||
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, tiles);
|
||||
ircd_gpt_tmul(out->fcon, in->word, bias, weight, width, height, tiles);
|
||||
|
||||
for(uint i = 0; i < tiles; ++i)
|
||||
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
|
||||
|
@ -153,7 +174,7 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Projection
|
||||
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
|
||||
ircd_gpt_tmul(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
|
||||
}
|
||||
|
||||
inline void
|
||||
|
@ -278,7 +299,7 @@ ircd_gpt_attn_proj(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
height = opts->attn_height;
|
||||
|
||||
// Projection
|
||||
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, 1);
|
||||
ircd_gpt_tmul(out->word, xattn->word, bias, weight, width, height, 1);
|
||||
}
|
||||
|
||||
__kernel void
|
||||
|
@ -411,7 +432,7 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
|
|||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Fully connected
|
||||
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
|
||||
ircd_gpt_tmul(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
|
||||
|
||||
// Export queries, keys, and values.
|
||||
for(uint i = 0; i < tiles; ++i)
|
||||
|
|
Loading…
Add table
Reference in a new issue