0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-06-10 22:18:54 +02:00

ircd::gpt: Optimizations for matrix multiply.

This commit is contained in:
Jason Volk 2021-10-03 17:47:45 -07:00
parent 1be7a8dea2
commit 8f90e7c0cd

View file

@ -51,21 +51,52 @@ ircd_gpt_ffnn_gelu(__local float4 *const out,
}
// Matrix * Vector Multiply/Accumulate
inline void
inline float4
__attribute__((flatten, always_inline))
ircd_gpt_sgemv(__local float4 *const restrict out,
__local const float4 *const restrict in,
__global const float4 *const restrict bias,
__global const float4 *const restrict weight,
const uint width,
const uint height,
const uint segs)
ircd_gpt_tmul_dot(__local const float4 *const restrict in,
__global const float4 *const restrict bias,
__global const float4 *const restrict weight,
const uint width,
const uint height,
const uint col,
const uint i,
const uint j)
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
lanes = 4;
float4
acc = 0.0f;
for(uint k = 0; k < lanes; ++k)
{
const uint
row = j * lanes + k,
cell = row * width + col;
acc += in[j][k] * weight[cell];
}
return acc;
}
// Matrix * Vector Multiply/Accumulate
inline void
__attribute__((flatten, always_inline))
ircd_gpt_tmul(__local float4 *const restrict out,
__local const float4 *const restrict in,
__global const float4 *const restrict bias,
__global const float4 *const restrict weight,
const uint width,
const uint height,
const uint segs)
{
const uint
li = get_local_id(0),
ln = get_local_size(0);
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < segs; ++i)
{
@ -75,24 +106,14 @@ ircd_gpt_sgemv(__local float4 *const restrict out,
out[col] = bias[col];
}
for(uint j = 0; j < height; ++j)
for(uint i = 0; i < segs; ++i)
{
const uint
col = i * ln + li;
for(uint i = 0; i < segs; ++i)
{
const uint
col = i * ln + li;
float4 acc = 0.0f;
for(uint k = 0; k < lanes; ++k)
{
const uint
row = j * lanes + k,
cell = row * width + col;
acc += in[j][k] * weight[cell];
}
out[col] += acc;
}
for(uint j = 0; j < height; ++j)
out[col] += ircd_gpt_tmul_dot(in, bias, weight, width, height, col, i, j);
}
}
inline void
@ -111,7 +132,7 @@ ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
height = opts->ffnn_height,
tiles = opts->ffnn_mult;
ircd_gpt_sgemv(out->fcon, in->word, bias, weight, width, height, tiles);
ircd_gpt_tmul(out->fcon, in->word, bias, weight, width, height, tiles);
for(uint i = 0; i < tiles; ++i)
ircd_gpt_ffnn_gelu(out->fcon, out->fcon, i * ln + li);
@ -153,7 +174,7 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
barrier(CLK_LOCAL_MEM_FENCE);
// Projection
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
ircd_gpt_tmul(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
}
inline void
@ -278,7 +299,7 @@ ircd_gpt_attn_proj(__global const struct ircd_gpt_ctrl *const ctrl,
height = opts->attn_height;
// Projection
ircd_gpt_sgemv(out->word, xattn->word, bias, weight, width, height, 1);
ircd_gpt_tmul(out->word, xattn->word, bias, weight, width, height, 1);
}
__kernel void
@ -411,7 +432,7 @@ ircd_gpt_attn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
barrier(CLK_LOCAL_MEM_FENCE);
// Fully connected
ircd_gpt_sgemv(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
ircd_gpt_tmul(token.fcon, tmp, fcon_bias, fcon_weight, width, height, tiles);
// Export queries, keys, and values.
for(uint i = 0; i < tiles; ++i)