2021-03-30 03:22:42 +02:00
|
|
|
// Matrix Construct
|
|
|
|
//
|
|
|
|
// Copyright (C) Matrix Construct Developers, Authors & Contributors
|
|
|
|
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
|
|
|
|
//
|
|
|
|
// Permission to use, copy, modify, and/or distribute this software for any
|
|
|
|
// purpose with or without fee is hereby granted, provided that the above
|
|
|
|
// copyright notice and this permission notice is present in all copies. The
|
|
|
|
// full license for this software is available in the LICENSE file.
|
|
|
|
|
2021-08-28 03:02:33 +02:00
|
|
|
#include <ircd/simt/simt.h>
|
|
|
|
#include <ircd/gpt/token.h>
|
2021-09-18 06:08:10 +02:00
|
|
|
#include <ircd/gpt/opts.h>
|
|
|
|
#include <ircd/gpt/ctrl.h>
|
2021-03-30 03:22:42 +02:00
|
|
|
|
|
|
|
/// Affine step of a layer normalization for a single float4 element:
///
///   out[i] = in[i] * weight[i] + bias[i]
///
/// Every caller in this file invokes it right after
/// ircd_simt_math_norm_f4lldr. `out` and `in` may be the same buffer
/// (neither is restrict-qualified).
inline void
__attribute__((always_inline))
ircd_gpt_norm_fmad(__local float4 *const out,
                   __local const float4 *const in,
                   __global const float4 *const restrict bias,
                   __global const float4 *const restrict weight,
                   const uint i)
{
	// Kept as one expression: the compiler's multiply-add contraction is
	// therefore the same as before.
	out[i] = bias[i] + weight[i] * in[i];
}
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
/// Gaussian Error Linear Unit, tanh approximation:
///
///   gelu(x) = 0.5 * x * (1 + tanh(0.7978845608 * (x + 0.044715 * x^3)))
///
/// (0.7978845608 ~= sqrt(2/pi)). Reads element `i` of `in_` and writes the
/// activation to `out[i]`; `out` and `in_` may be the same buffer. Each
/// arithmetic step is a separate statement so the evaluation order (and
/// therefore rounding) is fixed.
inline void
__attribute__((always_inline))
ircd_gpt_ffnn_gelu(__local float4 *const out,
                   __local const float4 *const in_,
                   const uint i)
{
	const float4 x = in_[i];

	// Inner argument: 0.7978845608 * (1 + 0.044715 * x * x) * x
	float4 t = 0.044715f * x;
	t = t * x;
	t = t + 1.0f;
	t = t * 0.7978845608f;
	t = t * x;

	// Outer: 0.5 * x * (1 + tanh(inner))
	t = tanh(t);
	t = t + 1.0f;
	t = t * x;
	t = t * 0.5f;

	out[i] = t;
}
|
|
|
|
|
2021-04-27 05:43:21 +02:00
|
|
|
/// Contribution of one input float4 to one output column of a
/// matrix * vector product:
///
///   sum over lane k in [0, 4) of in[j][k] * weight[(4 * j + k) * width + col]
///
/// Each scalar lane of `in[j]` selects one row of the weight matrix, which
/// is `width` float4s wide. `bias`, `height` and `i` are accepted for
/// signature symmetry with ircd_gpt_tmul but are not read here.
inline float4
__attribute__((flatten, always_inline))
ircd_gpt_tmul_dot(__local const float4 *const restrict in,
                  __global const float4 *const restrict bias,
                  __global const float4 *const restrict weight,
                  const uint width,
                  const uint height,
                  const uint col,
                  const uint i,
                  const uint j)
{
	const uint lanes = 4;
	const float4 x = in[j];

	float4 acc = 0.0f;
	for(uint lane = 0; lane < lanes; ++lane)
	{
		const uint row = j * lanes + lane;

		acc += x[lane] * weight[row * width + col];
	}

	return acc;
}
|
|
|
|
|
|
|
|
/// Matrix * vector multiply/accumulate with bias, split across the
/// work-group. For each of `segs` segments this work-item produces output
/// column `col = seg * get_local_size(0) + get_local_id(0)`:
///
///   out[col] = bias[col] + sum over j in [0, height) of
///              ircd_gpt_tmul_dot(in, bias, weight, width, height, col, seg, j)
///
/// Every work-item reads in[0 .. height), so the caller must make `in`
/// visible to the whole work-group (e.g. a local barrier) beforehand. Each
/// work-item writes only its own columns of `out`.
inline void
__attribute__((flatten, always_inline))
ircd_gpt_tmul(__local float4 *const restrict out,
              __local const float4 *const restrict in,
              __global const float4 *const restrict bias,
              __global const float4 *const restrict weight,
              const uint width,
              const uint height,
              const uint segs)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0);

	for(uint seg = 0; seg < segs; ++seg)
	{
		const uint col = seg * ln + li;

		// Start from the bias and accumulate in a register; the sum is built
		// in the same order as before and stored once.
		float4 sum = bias[col];
		for(uint j = 0; j < height; ++j)
			sum += ircd_gpt_tmul_dot(in, bias, weight, width, height, col, seg, j);

		out[col] = sum;
	}
}
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
/// Feed-forward expansion stage. Multiplies the token vector `in->word` by
/// the fcon weights plus bias into `out->fcon`, using opts->ffnn_mult
/// segments per work-item. GELU is then applied in place to the elements
/// this work-item produced. `ctrl` is not read.
inline void
__attribute__((flatten, always_inline))
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __local union ircd_gpt_ffnn_aperaturev *const restrict out,
                   __local const union ircd_gpt_tokenv *const in,
                   __global const float4 *const restrict bias,
                   __global const float4 *const restrict weight)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	tiles = opts->ffnn_mult;

	ircd_gpt_tmul
	(
		out->fcon,
		in->word,
		bias,
		weight,
		opts->ffnn_width,
		opts->ffnn_height,
		tiles
	);

	// No barrier needed: the multiply above wrote column tile * ln + li for
	// each tile, and those are exactly the columns activated here.
	for(uint tile = 0; tile < tiles; ++tile)
	{
		const uint col = tile * ln + li;

		ircd_gpt_ffnn_gelu(out->fcon, out->fcon, col);
	}
}
|
|
|
|
|
|
|
|
/// Feed-forward (MLP) sub-layer for one token held in local memory:
///
///   token = proj(gelu(fcon(norm(token))))
///
/// `token` is re-normalized in place (with `buf->word` passed as the norm's
/// third argument), expanded into `buf->fcon`, activated, and projected
/// back into `token`. The projection uses width and height swapped relative
/// to the fcon stage. `tmp0` and `tmp1` are not used here.
inline void
__attribute__((flatten, always_inline))
ircd_gpt_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
              __constant const struct ircd_gpt_opts *const opts,
              __local union ircd_gpt_tokenv *const restrict token,
              __local union ircd_gpt_ffnn_aperaturev *const restrict buf,
              __local union ircd_gpt_ffnn_aperaturev *const restrict tmp0,
              __local union ircd_gpt_tokenv *const restrict tmp1,
              __global const float4 *const restrict norm_bias,
              __global const float4 *const restrict norm_weight,
              __global const float4 *const restrict fcon_bias,
              __global const float4 *const restrict fcon_weight,
              __global const float4 *const restrict proj_bias,
              __global const float4 *const restrict proj_weight)
{
	const uint
	li = get_local_id(0),
	fcon_width = opts->ffnn_width,
	fcon_height = opts->ffnn_height;

	// Layer re-normalization
	ircd_simt_math_norm_f4lldr(token->word, token->word, buf->word);
	ircd_gpt_norm_fmad(token->word, token->word, norm_bias, norm_weight, li);

	// The norm's writes are still pending, but fcon reads the token across
	// work-items.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Fully connected (expansion + GELU)
	ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight);

	// fcon's writes are still pending, but the projection reads them across
	// work-items.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Projection back to the token width (dimensions swapped vs. fcon).
	ircd_gpt_tmul(token->word, buf->fcon, proj_bias, proj_weight, fcon_height, fcon_width, 1);
}
|
|
|
|
|
2021-05-02 23:51:49 +02:00
|
|
|
/// In-place softmax over this work-item's column of `self`, i.e. over
/// self[0 .. wn)[get_local_id(0)], using the max-subtraction form:
///
///   self[i][li] = exp(self[i][li] - mu) / sum_i exp(self[i][li] - mu)
///
/// with mu the column maximum (floored at -10000). Only this work-item's
/// column is touched. Because `self` rows are 12 floats wide, the caller
/// must keep get_local_id(0) below 12. `ctrl` and `opts` are not read.
inline void
__attribute__((flatten, always_inline))
ircd_gpt_attn_self_samax(__global const struct ircd_gpt_ctrl *const ctrl,
                         __constant const struct ircd_gpt_opts *const opts,
                         __local float self[][12],
                         const uint wn)
{
	const uint col = get_local_id(0);

	struct ircd_math_samax samax =
	{
		.mu = -10000.0f,
		.sum = 0.0f,
	};

	// Column maximum.
	for(uint row = 0; row < wn; ++row)
		samax.mu = max(samax.mu, self[row][col]);

	// Exponentiate relative to the maximum and accumulate the normalizer in
	// the same ascending order as before.
	__attribute__((opencl_unroll_hint))
	for(uint row = 0; row < wn; ++row)
	{
		self[row][col] = exp(self[row][col] - samax.mu);
		samax.sum += self[row][col];
	}

	samax.lambda = 1.0f / samax.sum;

	// Scale to probabilities.
	__attribute__((opencl_unroll_hint))
	for(uint row = 0; row < wn; ++row)
		self[row][col] *= samax.lambda;
}
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
/// Masked (causal) self-attention for the token at window position `wi`.
///
/// Phase 1 runs on work-items li < opts->attn_rank, one per attention head
/// (at most 12, the row width of `self`). For each token i in [0, wn) it
/// computes the score
///
///   self[i][li] = dot(token[wi].qry head li, token[i].key head li) / 8
///
/// over kn = ln / attn_rank float4s per head. Tokens after the current one
/// (i > wi) are masked to -10000. The column is then softmaxed. The /8 is
/// presumably 1/sqrt(head width) for 64-float heads — TODO confirm.
///
/// Phase 2 runs on all work-items, after a local barrier. Work-item li
/// produces float4 `ki` of head `ti` as the attention-weighted sum of the
/// value vectors of tokens 0 .. wi inclusive, written to out->attn[ti][ki].
inline void
__attribute__((flatten, always_inline))
ircd_gpt_attn_self(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __local union ircd_gpt_tokenv *const restrict out,
                   __local float self[][12],
                   __global const struct ircd_gpt_attn_qkvv *const restrict token)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	wi = get_global_offset(0) / ln + get_group_id(0),
	wn = ctrl->tokens.count,
	ti = li % opts->attn_rank,
	ki = li / opts->attn_rank,
	kn = ln / opts->attn_rank;

	// Low-rank mask: one work-item per head computes that head's scores.
	if(li < opts->attn_rank)
	{
		// For each token
		for(uint i = 0; i < wn; ++i)
		{
			// Left-attention mask: a token may only attend to itself and
			// earlier tokens.
			if(wi < i)
			{
				self[i][li] = -10000.0f;
				continue;
			}

			float4 acc = 0.0f;
			__attribute__((opencl_unroll_hint))
			for(uint k = 0; k < kn; ++k)
			{
				float4
				qry = token[wi].qry.attn[li][k],
				key = token[i].key.attn[li][k];

				acc += qry * key;
			}

			const float
			sum = ircd_simt_reduce_add_f4(acc),
			res = sum / 8.0f;

			self[i][li] = res;
		}

		// Three-piece softmax
		ircd_gpt_attn_self_samax(ctrl, opts, self, wn);
	}

	// Propagate to full width for value dot prod.
	barrier(CLK_LOCAL_MEM_FENCE);

	// The mask admits tokens i <= wi, so the current token's own weight must
	// be included; stopping at i < wi dropped it (and produced zero output
	// for wi == 0). Also bound by wn: self[i] is only written for i < wn.
	const uint vn = min(wi + 1, wn);

	float4 acc = 0.0f;
	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < vn; ++i)
	{
		const float4
		attn = self[i][ti],
		val = token[i].val.attn[ti][ki];

		acc += attn * val;
	}

	out->attn[ti][ki] = acc;
}
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
/// Output projection of the self-attention result:
///
///   out = weight * xattn + bias
///
/// The matrix is square (opts->attn_height on both sides), with one output
/// float4 per work-item (a single segment). `ctrl` is not read.
inline void
__attribute__((flatten, always_inline))
ircd_gpt_attn_proj(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __local union ircd_gpt_tokenv *const out,
                   __local const union ircd_gpt_tokenv *const xattn,
                   __global const float4 *const restrict bias,
                   __global const float4 *const restrict weight)
{
	// Square projection: width and height are both attn_height.
	const uint dim = opts->attn_height;

	ircd_gpt_tmul(out->word, xattn->word, bias, weight, dim, dim, 1);
}
|
|
|
|
|
|
|
|
/// Kernel: one transformer block for the token at window position `wi`,
/// with one float4 of the embedding per work-item. With x = accum[wi]:
///
///   a         = attn_proj(self_attention(state))
///   x'        = x + a
///   accum[wi] = x' + ffnn(x')
///
/// Queries, keys and values for all tokens are read from `state`.
__kernel void
__attribute__((flatten))
ircd_gpt_coil(__global const struct ircd_gpt_ctrl *const ctrl,
              __constant const struct ircd_gpt_opts *const opts,
              __global union ircd_gpt_tokenv *const restrict accum,
              __global const struct ircd_gpt_attn_qkvv *const restrict state,
              __global const float4 *const restrict attn_proj_bias,
              __global const float4 *const restrict attn_proj_weight,
              __global const float4 *const restrict ffnn_norm_bias,
              __global const float4 *const restrict ffnn_norm_weight,
              __global const float4 *const restrict ffnn_fcon_bias,
              __global const float4 *const restrict ffnn_fcon_weight,
              __global const float4 *const restrict ffnn_proj_bias,
              __global const float4 *const restrict ffnn_proj_weight)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	wi = get_global_offset(0) / ln + get_group_id(0);

	__local union ircd_gpt_tokenv
	buf1, buf0;

	// Scratch shared between phases. The attention scores are only used
	// before the attention projection, and the ffnn apertures are only used
	// after it, so the two views overlap.
	// NOTE(review): attn_self holds 512 token rows; presumably
	// ctrl->tokens.count must not exceed that — TODO confirm.
	__local union
	{
		union ircd_gpt_ffnn_aperaturev
		ffnn_fcon[2];

		float
		attn_self[512][12];
	}
	buf;

	// Self-attention backend; this computes the self-attention result now
	// that keys and values are globally visible across tokens.
	ircd_gpt_attn_self(ctrl, opts, &buf1, buf.attn_self, state);

	// The projection reads the attention output across work-items.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Project result of self-attention.
	ircd_gpt_attn_proj(ctrl, opts, &buf0, &buf1, attn_proj_bias, attn_proj_weight);

	// Frontend accumulation: residual add, kept both in the local copy
	// (input to the mlp) and in the global accumulator.
	{
		const float4
		attn = buf0.word[li],
		resid = accum[wi].word[li];

		buf0.word[li] = attn + resid;
		accum[wi].word[li] = resid + attn;
	}

	// Backend mlp; layer-norm acquires any pending writes, no fence required.
	ircd_gpt_ffnn
	(
		ctrl,
		opts,
		&buf0,
		buf.ffnn_fcon + 0,
		buf.ffnn_fcon + 1,
		&buf1,
		ffnn_norm_bias,
		ffnn_norm_weight,
		ffnn_fcon_bias,
		ffnn_fcon_weight,
		ffnn_proj_bias,
		ffnn_proj_weight
	);

	// Backend accumulation: second residual add.
	{
		const float4
		ffnn = buf0.word[li],
		resid = accum[wi].word[li];

		accum[wi].word[li] = ffnn + resid;
	}
}
|
|
|
|
|
|
|
|
/// Kernel: attention front-end for the token at window position `wi`.
/// Layer-normalizes a local copy of accum[wi] into `tmp`, multiplies it by
/// the fcon weights plus bias into `token.fcon` (opts->attn_mult segments
/// per work-item), and exports the result to state[wi].proj — the queries,
/// keys and values for this token.
__kernel void
__attribute__((flatten))
ircd_gpt_attn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global union ircd_gpt_attn_aperaturev *const restrict state,
                   __global const union ircd_gpt_tokenv *const restrict accum,
                   __global const float4 *const restrict norm_bias,
                   __global const float4 *const restrict norm_weight,
                   __global const float4 *const restrict fcon_bias,
                   __global const float4 *const restrict fcon_weight)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	wi = get_global_offset(0) / ln + get_group_id(0),
	tiles = opts->attn_mult;

	__local union ircd_gpt_attn_aperaturev
	token;

	// Receives the normalized token; it is also passed as the norm's third
	// argument. NOTE(review): sized for a 768-float embedding; presumably
	// must match opts / the local size — TODO confirm.
	__local float4
	tmp[768/4];

	token.word[li] = accum[wi].word[li];

	// Layer re-normalization; the affine result lands in `tmp`, not in place.
	ircd_simt_math_norm_f4lldr(token.word, token.word, tmp);
	ircd_gpt_norm_fmad(tmp, token.word, norm_bias, norm_weight, li);

	// The norm's writes are still pending; fcon requires results across
	// work-items.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Fully connected
	ircd_gpt_tmul(token.fcon, tmp, fcon_bias, fcon_weight, opts->attn_width, opts->attn_height, tiles);

	// Export queries, keys, and values.
	for(uint tile = 0; tile < tiles; ++tile)
		state[wi].proj[tile][li] = token.proj[tile][li];
}
|
|
|
|
|
2021-04-02 22:01:38 +02:00
|
|
|
//
|
|
|
|
// frontend
|
|
|
|
//
|
2021-03-30 03:22:42 +02:00
|
|
|
|
2021-04-02 22:01:38 +02:00
|
|
|
/// Writes one float4 of a token's input embedding:
///
///   accum[out_idx].word[word_idx] = vocab[id].word[word_idx]
///                                 + pos[tok_idx].word[word_idx]
///
/// `id` is the token stored in the ctrl->token ring buffer at logical
/// position `tok_idx`, counted from ctrl->tokens.head and wrapping at
/// opts->buffer_tokens.
inline void
__attribute__((always_inline))
_ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global union ircd_gpt_tokenv *const restrict accum,
                   __global const union ircd_gpt_tokenv *const restrict pos,
                   __global const union ircd_gpt_tokenv *const restrict vocab,
                   const uint out_idx,
                   const uint tok_idx,
                   const uint word_idx)
{
	// Physical ring slot holding the tok_idx'th token of the window.
	const ushort slot = (ctrl->tokens.head + tok_idx) % opts->buffer_tokens;
	const ushort id = ctrl->token[slot];

	accum[out_idx].word[word_idx] = vocab[id].word[word_idx] + pos[tok_idx].word[word_idx];
}
|
|
|
|
|
|
|
|
/// Kernel: input embedding (token + position) for the token at window
/// position `wi`, one float4 of the embedding per work-item. The window
/// position serves as both the token index and the output row of `accum`.
__kernel void
__attribute__((flatten))
ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
                  __constant const struct ircd_gpt_opts *const opts,
                  __global union ircd_gpt_tokenv *const restrict accum,
                  __global const union ircd_gpt_tokenv *const restrict pos,
                  __global const union ircd_gpt_tokenv *const restrict vocab)
{
	const uint
	word = get_local_id(0),
	wi = get_global_offset(0) / get_local_size(0) + get_group_id(0);

	_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, wi, wi, word);
}
|
|
|
|
|
|
|
|
/// Kernel: final re-normalization of accum[wi], written back in place.
/// Each work-item handles one float4 of the token.
__kernel void
__attribute__((flatten))
ircd_gpt_lm_norm(__global const struct ircd_gpt_ctrl *const ctrl,
                 __constant const struct ircd_gpt_opts *const opts,
                 __global union ircd_gpt_tokenv *const restrict accum,
                 __global const float4 *const restrict norm_bias,
                 __global const float4 *const restrict norm_weight)
{
	const uint
	word = get_local_id(0),
	wi = get_global_offset(0) / get_local_size(0) + get_group_id(0);

	__local union ircd_gpt_tokenv
	token, scratch;

	token.word[word] = accum[wi].word[word];

	// Final re-normalization
	ircd_simt_math_norm_f4lldr(token.word, token.word, scratch.word);
	ircd_gpt_norm_fmad(token.word, token.word, norm_bias, norm_weight, word);

	accum[wi].word[word] = token.word[word];
}
|
|
|
|
|
|
|
|
/// Kernel: output logits for the last token in the window.
///
/// Work-item `gi` stores into logit[gi] the dot product of the final
/// hidden state accum[ctrl->tokens.count - 1] with vocabulary row token[gi],
/// taken over opts->embed_width float4s. Work-items at or beyond
/// opts->logits (global-size padding) store -10000 and return. They never
/// read token[gi], which presumably has only opts->logits rows.
///
/// NOTE(review): ctrl->tokens.count == 0 would make the row index wrap
/// around — confirm callers never launch with an empty window.
__kernel void
__attribute__((flatten))
ircd_gpt_lm_logit(__global const struct ircd_gpt_ctrl *const ctrl,
                  __constant const struct ircd_gpt_opts *const opts,
                  __global float *const restrict logit,
                  __global const union ircd_gpt_tokenv *const restrict accum,
                  __global const union ircd_gpt_tokenv *const restrict token)
{
	const uint
	gi = get_global_id(0),
	ti = ctrl->tokens.count - 1,
	words = opts->embed_width;

	// Padding lane: only write the mask value. The previous version computed
	// the dot product first, reading token[gi] past the vocabulary table.
	// No barriers follow, so the early return is safe.
	if(gi >= opts->logits)
	{
		logit[gi] = -10000.0f;
		return;
	}

	float4 acc = 0.0f;
	__attribute__((opencl_unroll_hint))
	for(uint j = 0; j < words; ++j)
	{
		const float4
		in = accum[ti].word[j],
		vocab = token[gi].word[j];

		acc += vocab * in;
	}

	logit[gi] = ircd_simt_reduce_add_f4(acc);
}
|
|
|
|
|
|
|
|
/// Kernel: softmax over the logit vector, computed by one work-group.
/// Work-item li scans the `run` float4s starting at float4 index run * li.
///
///   1. Max-reduce to mu (stored in ctrl->samax.mu).
///   2. logexp = exp(logit - mu), zero for float indices >= opts->logits;
///      sum-reduce to sum (ctrl->samax.sum, ctrl->samax.lambda = 1 / sum).
///   3. logsm = logexp * lambda.
///
/// The leader (li == 0) publishes the reduced values through share4 and
/// ircd_simt_broadcast_f4lldr; `share`/`share4` hold 256 entries, so the
/// local size must not exceed 256.
///
/// NOTE(review): `run` uses truncating division, so run * ln * 4 can be
/// below opts->logits. For example, logits = 50257 and ln = 256 give
/// run = 49, covering only 50176 logits. Also, when logits is already a
/// multiple of ln, the align-up adds a whole extra ln. Uncovered logits are
/// left out of mu and sum, and their logsm entries are never written.
/// Fixing this depends on host buffer sizes — TODO confirm.
__kernel void
__attribute__((flatten))
ircd_gpt_lm_logsm(__global struct ircd_gpt_ctrl *const ctrl,
                  __constant const struct ircd_gpt_opts *const opts,
                  __global float4 *const restrict logsm,
                  __global float4 *const restrict logexp,
                  __global const float4 *const restrict logit)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	logits = opts->logits,
	logits_alignup = logits + (ln - (logits % ln)),
	run = logits_alignup / ln / 4,
	first = run * li,
	last = first + run;

	__local float share[256];
	__local float4 share4[256];

	// 1a. Per-work-item lane-wise maximum over its run...
	share4[li] = -10000.0f;
	for(uint i = first; i < last; ++i)
		share4[li] = max(share4[li], logit[i]);

	// 1b. ...folded to a scalar...
	share[li] = -10000.0f;
	for(uint k = 0; k < 4; ++k)
		share[li] = max(share[li], share4[li][k]);

	// 1c. ...reduced across the group and broadcast by the leader.
	ircd_simt_reduce_max_flldr(share);

	if(li == 0)
		share4[li] = ctrl->samax.mu = share[li];

	ircd_simt_broadcast_f4lldr(share4);

	const float4 mu = share4[li];

	// 2. Exponentiate relative to the maximum, masking the padded tail.
	share4[li] = 0.0f;
	for(uint i = first; i < last; ++i)
	{
		const float4 shifted = logit[i] - mu;

		float4 e;
		for(uint k = 0; k < 4; ++k)
			e[k] = (i * 4 + k < logits)? exp(shifted[k]): 0.0f;

		share4[li] += e;
		logexp[i] = e;
	}

	ircd_simt_reduce_add_f4lldr(share4);

	if(li == 0)
	{
		const float total = ircd_simt_reduce_add_f4(share4[li]);

		share4[li][0] = ctrl->samax.sum = total;
		share4[li][1] = ctrl->samax.lambda = 1.0f / total;
	}

	ircd_simt_broadcast_f4lldr(share4);

	// 3. Normalize by the broadcast reciprocal of the sum.
	const float4 lambda = share4[li][1];
	for(uint i = first; i < last; ++i)
		logsm[i] = logexp[i] * lambda;
}
|
|
|
|
|
2021-04-02 22:01:38 +02:00
|
|
|
/// Records the i'th ranked candidate in ctrl->top[i]: its token id idx[i]
/// and that token's softmax probability logsm[idx[i]]. `opts` and `logit`
/// are not read.
inline void
__attribute__((always_inline))
ircd_gpt_lm_result_top(__global struct ircd_gpt_ctrl *const ctrl,
                       __constant const struct ircd_gpt_opts *const opts,
                       __local const ushort *const restrict idx,
                       __global const float *const restrict logsm,
                       __global const float *const restrict logit,
                       const uint i)
{
	const ushort id = idx[i];
	const float prob = logsm[id];

	ctrl->top[i].token = id;
	ctrl->top[i].samax = prob;
}
|
|
|
|
|
|
|
|
/// Updates the running statistics of label i (ctrl->label[i]) for this
/// cycle:
///
///   samax = logsm[label->token]
///   loss  = -log(samax)
///   perp  = (1 - samax) * log2(opts->logits)
///
/// For loss and perp, `last` receives this cycle's value, and the value is
/// added to one of three partial sums chosen by ctrl->epic.cycle % 3.
/// `mean` is (sum of the three partial sums before this update + value) /
/// (cycle + 1). `idx` and `logit` are not read.
///
/// NOTE(review): if logsm[token] is 0 the loss is +inf and stays in the
/// partial sums — confirm whether clamping is wanted.
inline void
__attribute__((always_inline))
ircd_gpt_lm_result_label(__global struct ircd_gpt_ctrl *const ctrl,
                         __constant const struct ircd_gpt_opts *const opts,
                         __local const ushort *const restrict idx,
                         __global const float *const restrict logsm,
                         __global const float *const restrict logit,
                         const uint i)
{
	__global struct ircd_gpt_ctrl_label
	*const label = ctrl->label + i;

	const ushort
	token = label->token,
	sum_sel = ctrl->epic.cycle % 3;

	const float samax = logsm[token];
	const float mean_div = ctrl->epic.cycle + 1.0f;

	// All reads of the partial sums happen before any write below.
	const float loss = 0.0f - log(samax);
	const float loss_sum = label->loss.sum[0] + label->loss.sum[1] + label->loss.sum[2] + loss;

	const float perp = (1.0f - samax) * native_log2(opts->logits);
	const float perp_sum = label->perp.sum[0] + label->perp.sum[1] + label->perp.sum[2] + perp;

	label->samax = samax;

	label->loss.last = loss;
	label->loss.sum[sum_sel] += loss;
	label->loss.mean = loss_sum / mean_div;

	label->perp.last = perp;
	label->perp.sum[sum_sel] += perp;
	label->perp.mean = perp_sum / mean_div;
}
|
|
|
|
|
|
|
|
/// Samples the next token from the ranked candidates in `idx` and appends
/// it to the ctrl->token ring buffer.
///
/// The threshold is (rnd % top_p) / 100, with top_p clamped to [1, 100] and
/// rnd drawn from ircd_simt_rand_xoshiro256pg(ctrl->rand). Candidates
/// idx[0], idx[1], ... are walked while accumulating logsm[idx[s]]. The
/// first candidate whose running sum exceeds the threshold is chosen. At
/// most max(opts->top_k, 1) candidates are considered; if none crosses the
/// threshold, the last one considered is chosen.
///
/// When the ring is full (count >= buffer_tokens), the new token overwrites
/// the slot at `head` and `head` advances; otherwise `count` grows by one.
/// `logexp` and `logit` are not read.
///
/// NOTE(review): `idx` must hold at least max(top_k, 1) entries; its length
/// is not visible here.
inline void
__attribute__((always_inline))
ircd_gpt_lm_result_select(__global struct ircd_gpt_ctrl *const ctrl,
                          __constant const struct ircd_gpt_opts *const opts,
                          __local const ushort *const restrict idx,
                          __global const float *const restrict logsm,
                          __global const float *const restrict logexp,
                          __global const float *const restrict logit)
{
	const bool
	buffer_full = ctrl->tokens.count >= opts->buffer_tokens;

	const ulong
	rnd = ircd_simt_rand_xoshiro256pg(ctrl->rand),
	ent_k = max(opts->top_k, 1U),
	ent_p = max(1U, min(opts->top_p, 100U));

	const float
	thresh = (rnd % ent_p) / 100.0f;

	// Stop at the first candidate whose cumulative probability crosses the
	// threshold, but never move past the last of the ent_k candidates.
	// Previously the loop could finish with select == top_k, picking
	// idx[top_k], which is outside the top-k set.
	ushort select = 0;
	float smacc = 0.0f;
	for(; select + 1 < ent_k; ++select)
		if((smacc += logsm[idx[select]]) > thresh)
			break;

	const ushort
	token = idx[select],
	dest = (ctrl->tokens.head + ctrl->tokens.count) % opts->buffer_tokens,
	tokens = min(ctrl->tokens.count + 1, opts->buffer_tokens),
	head = buffer_full?
		(ctrl->tokens.head + 1) % opts->buffer_tokens: ctrl->tokens.head;

	ctrl->tokens.head = head;
	ctrl->tokens.count = tokens;
	ctrl->token[dest] = token;
}
|
2021-04-17 21:01:12 +02:00
|
|
|
|
2021-09-18 08:27:23 +02:00
|
|
|
/// End-of-cycle bookkeeping. Increments ctrl->epic.cycle, and also
/// ctrl->epic.epoch when the incremented cycle reaches opts->limit. Then
/// stamps ctrl->magic. `li` is not read.
inline void
__attribute__((always_inline))
ircd_gpt_leave(__global struct ircd_gpt_ctrl *const ctrl,
               __constant const struct ircd_gpt_opts *const opts,
               const uint li)
{
	const bool epoch_done = ctrl->epic.cycle + 1 >= opts->limit;

	if(epoch_done)
		ctrl->epic.epoch += 1;

	ctrl->epic.cycle += 1;
	ctrl->magic = 0xC7012C70U;
}
|
|
|
|
|
2021-03-30 03:22:42 +02:00
|
|
|
/// Kernel: ranks the softmax output and commits this cycle's result.
///
/// Each work-item finds the highest-probability token in its slice of
/// `logsm`: tn = opts->logits / ln entries starting at tn * li. The
/// per-work-item winners in `idx` are then sorted with
/// ircd_simt_sort_idx16_flldr. Work-items below opts->top_n record
/// ctrl->top[li], and work-items below opts->labels update ctrl->label[li].
/// Finally, the leader samples the next token and runs the end-of-cycle
/// bookkeeping.
///
/// NOTE(review): the slice size truncates, so the last opts->logits % ln
/// tokens are never candidates. `idx` holds 256 entries, so the local size
/// must not exceed 256.
__kernel void
__attribute__((flatten))
ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global const float *const restrict logsm,
                   __global const float *const restrict logit)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	tn = opts->logits / ln,
	ti = tn * li;

	__local ushort idx[256];

	// Arg-max of this work-item's slice.
	idx[li] = ti;
	for(uint j = ti + 1; j < ti + tn; ++j)
		if(logsm[j] > logsm[idx[li]])
			idx[li] = j;

	ircd_simt_sort_idx16_flldr(idx, logsm);

	// This kernel has no `logexp` argument, and result_top/result_label do
	// not take one. The previous calls passed an undeclared `logexp` plus an
	// extra argument and did not compile.
	if(li < opts->top_n)
		ircd_gpt_lm_result_top(ctrl, opts, idx, logsm, logit, li);

	if(li < opts->labels)
		ircd_gpt_lm_result_label(ctrl, opts, idx, logsm, logit, li);

	// Writes to `idx` from the sort are still pending across threads.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Mask for write-leader; nothing below synchronizes, so the others may
	// leave now.
	if(li != 0)
		return;

	// result_select declares a `logexp` parameter that it never reads. No
	// such buffer is bound to this kernel, so the read-only `logsm` fills
	// that slot.
	ircd_gpt_lm_result_select(ctrl, opts, idx, logsm, logsm, logit);
	ircd_gpt_leave(ctrl, opts, li);
}
|
2021-04-17 20:59:30 +02:00
|
|
|
|
|
|
|
//
|
|
|
|
// backpropagations
|
|
|
|
//
|
|
|
|
|
|
|
|
/// Apply one Adam-style optimizer step (without bias correction) to the
/// element at get_local_id(0) of a parameter row.
///
/// `param_` is updated in place. The first and second moment estimates in
/// `exp_avg_` and `exp_avg_sqr_` are overwritten with their new values.
/// When `ctrl->epic.step` (truncated to uint) is zero, the stored moments are
/// not read and are treated as zero.
///
/// The gradient is taken from `ctrl->label[0].loss.mean` and converted to
/// float4 (presumably a scalar broadcast to every component; confirm).
inline void
__attribute__((always_inline))
ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global float4 *const restrict param_,
                   __global float4 *const restrict exp_avg_,
                   __global float4 *const restrict exp_avg_sqr_)
{
	const uint
	i = get_local_id(0),
	step = ctrl->epic.step;

	const float4
	p = param_[i],
	g = ctrl->label[0].loss.mean;

	// Moments carried over from the previous step. Nothing has been stored
	// yet on step 0, so start from zero instead of reading the buffers.
	float4 m = 0.0f, v = 0.0f;
	if(step != 0)
	{
		m = exp_avg_[i];
		v = exp_avg_sqr_[i];
	}

	// Exponential moving averages of the gradient and of its square. The
	// decayed terms are computed as separate values before the sums.
	const float4
	m_decayed = m * opts->beta[0],
	v_decayed = v * opts->beta[1],
	m_next = m_decayed + (1.0f - opts->beta[0]) * g,
	v_next = v_decayed + (1.0f - opts->beta[1]) * g * g;

	// Learning-rate scaled step, damped by sqrt of the second moment.
	const float4
	denominator = sqrt(v_next) + opts->epsilon,
	step_delta = opts->alpha * (m_next / denominator);

	param_[i] = p - step_delta;
	exp_avg_[i] = m_next;
	exp_avg_sqr_[i] = v_next;
}
|
|
|
|
|
|
|
|
/// Optimizer step for a bias vector and a weight vector. Each vector has its
/// own pair of moment buffers (`*_m0`, `*_m1`).
///
/// The coil propagation kernels in this file use this for their `norm_*`
/// parameters. Each work-item updates element get_local_id(0) of both
/// vectors through ircd_gpt_prop_elem().
__kernel void
__attribute__((flatten))
ircd_gpt_norm_prop(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global union ircd_gpt_tokenv *const restrict bias,
                   __global union ircd_gpt_tokenv *const restrict bias_m0,
                   __global union ircd_gpt_tokenv *const restrict bias_m1,
                   __global union ircd_gpt_tokenv *const restrict weight,
                   __global union ircd_gpt_tokenv *const restrict weight_m0,
                   __global union ircd_gpt_tokenv *const restrict weight_m1)
{
	// Bias first, then weight: same order as the parameter list.
	ircd_gpt_prop_elem(ctrl, opts, bias->word, bias_m0->word, bias_m1->word);
	ircd_gpt_prop_elem(ctrl, opts, weight->word, weight_m0->word, weight_m1->word);
}
|
|
|
|
|
|
|
|
/// Optimizer step for every parameter set of an attention coil layer, in
/// this order:
///  1. `norm_*` bias and weight, through ircd_gpt_norm_prop();
///  2. the three `proj` sub-vectors of the `fcon_bias` aperture;
///  3. the three `proj` sub-vectors of each of the 768 `fcon_weight` rows;
///  4. the `proj_bias` vector;
///  5. each of the 768 `proj_weight` rows.
/// Each work-item updates element get_local_id(0) of every vector visited.
__kernel void
__attribute__((flatten))
ircd_gpt_coil_prop_attn(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global union ircd_gpt_tokenv *const restrict norm_bias,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict norm_weight,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m1,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_bias,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_bias_m0,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_bias_m1,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_weight,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_weight_m0,
                        __global union ircd_gpt_attn_aperaturev *const restrict fcon_weight_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_bias,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_weight,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m1)
{
	// Dimensions hard-coded by this kernel: row count of both weight
	// matrices, and number of `proj` sub-vectors per fcon aperture.
	const uint
	rows = 768,
	projs = 3;

	ircd_gpt_norm_prop(ctrl, opts,
	                   norm_bias, norm_bias_m0, norm_bias_m1,
	                   norm_weight, norm_weight_m0, norm_weight_m1);

	for(uint p = 0; p < projs; ++p)
		ircd_gpt_prop_elem(ctrl, opts,
		                   fcon_bias->proj[p], fcon_bias_m0->proj[p], fcon_bias_m1->proj[p]);

	for(uint r = 0; r < rows; ++r)
		for(uint p = 0; p < projs; ++p)
			ircd_gpt_prop_elem(ctrl, opts,
			                   fcon_weight[r].proj[p],
			                   fcon_weight_m0[r].proj[p],
			                   fcon_weight_m1[r].proj[p]);

	ircd_gpt_prop_elem(ctrl, opts,
	                   proj_bias->word, proj_bias_m0->word, proj_bias_m1->word);

	for(uint r = 0; r < rows; ++r)
		ircd_gpt_prop_elem(ctrl, opts,
		                   proj_weight[r].word, proj_weight_m0[r].word, proj_weight_m1[r].word);
}
|
|
|
|
|
|
|
|
/// Optimizer step for every parameter set of a feed-forward coil layer, in
/// this order:
///  1. `norm_*` bias and weight, through ircd_gpt_norm_prop();
///  2. the four `proj` sub-vectors of the `fcon_bias` aperture;
///  3. the four `proj` sub-vectors of each of the 768 `fcon_weight` rows;
///  4. the `proj_bias` vector;
///  5. each of the 3072 `proj_weight` rows.
/// Each work-item updates element get_local_id(0) of every vector visited.
__kernel void
__attribute__((flatten))
ircd_gpt_coil_prop_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global union ircd_gpt_tokenv *const restrict norm_bias,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict norm_weight,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict norm_weight_m1,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m0,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m1,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m0,
                        __global union ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_bias,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_bias_m1,
                        __global union ircd_gpt_tokenv *const restrict proj_weight,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m0,
                        __global union ircd_gpt_tokenv *const restrict proj_weight_m1)
{
	// Dimensions hard-coded by this kernel: fcon weight rows, `proj`
	// sub-vectors per fcon aperture, and proj weight rows (768 * 4 = 3072).
	const uint
	fcon_rows = 768,
	projs = 4,
	proj_rows = fcon_rows * projs;

	ircd_gpt_norm_prop(ctrl, opts,
	                   norm_bias, norm_bias_m0, norm_bias_m1,
	                   norm_weight, norm_weight_m0, norm_weight_m1);

	for(uint p = 0; p < projs; ++p)
		ircd_gpt_prop_elem(ctrl, opts,
		                   fcon_bias->proj[p], fcon_bias_m0->proj[p], fcon_bias_m1->proj[p]);

	for(uint r = 0; r < fcon_rows; ++r)
		for(uint p = 0; p < projs; ++p)
			ircd_gpt_prop_elem(ctrl, opts,
			                   fcon_weight[r].proj[p],
			                   fcon_weight_m0[r].proj[p],
			                   fcon_weight_m1[r].proj[p]);

	ircd_gpt_prop_elem(ctrl, opts,
	                   proj_bias->word, proj_bias_m0->word, proj_bias_m1->word);

	for(uint r = 0; r < proj_rows; ++r)
		ircd_gpt_prop_elem(ctrl, opts,
		                   proj_weight[r].word, proj_weight_m0[r].word, proj_weight_m1[r].word);
}
|
|
|
|
|
|
|
|
/// Optimizer step for the positional (`pos`) and token (`token`) embedding
/// tables.
///
/// Rows are divided among work-groups by `ctrl->tokens.count`. Work-group
/// `wi` (counted from the global offset) updates one contiguous slice of
/// the `opts->context_tokens` rows of `pos` and one contiguous slice of the
/// `opts->logits` rows of `token`. The last group also takes any remainder,
/// so no row is skipped when the counts do not divide evenly. Within a
/// group, each work-item updates element get_local_id(0) of every row
/// through ircd_gpt_prop_elem().
///
/// NOTE(review): this assumes the launch provides ctrl->tokens.count
/// work-groups (one per token). Confirm against the host-side dispatch.
__kernel void
__attribute__((flatten))
ircd_gpt_lm_embed_prop(__global const struct ircd_gpt_ctrl *const ctrl,
                       __constant const struct ircd_gpt_opts *const opts,
                       __global union ircd_gpt_tokenv *const restrict pos,
                       __global union ircd_gpt_tokenv *const restrict pos_m0,
                       __global union ircd_gpt_tokenv *const restrict pos_m1,
                       __global union ircd_gpt_tokenv *const restrict token,
                       __global union ircd_gpt_tokenv *const restrict token_m0,
                       __global union ircd_gpt_tokenv *const restrict token_m1)
{
	const uint
	ln = get_local_size(0),
	wi = get_global_offset(0) / ln + get_group_id(0),
	wn = ctrl->tokens.count;

	// With no tokens there is nothing to divide (and the divisions below
	// would be by zero). Groups past the token count have no slice. The
	// last group already covers the tail, so they would otherwise index
	// past the tables or update tail rows twice.
	if(wn == 0 || wi >= wn)
		return;

	const bool last = wi + 1 == wn;

	const uint
	cn = opts->context_tokens / wn,
	ci = cn * wi,
	ce = last? opts->context_tokens: ci + cn,
	tn = opts->logits / wn,
	ti = tn * wi,
	te = last? opts->logits: ti + tn;

	for(uint i = ci; i < ce; ++i)
		ircd_gpt_prop_elem(ctrl, opts, pos[i].word, pos_m0[i].word, pos_m1[i].word);

	for(uint i = ti; i < te; ++i)
		ircd_gpt_prop_elem(ctrl, opts, token[i].word, token_m0[i].word, token_m1[i].word);
}
|