0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-12-11 08:02:59 +01:00
construct/ircd/gpt_gpu.cl

1455 lines
38 KiB
Common Lisp
Raw Normal View History

2021-03-30 03:22:42 +02:00
// Matrix Construct
//
// Copyright (C) Matrix Construct Developers, Authors & Contributors
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice is present in all copies. The
// full license for this software is available in the LICENSE file.
2022-06-20 03:59:29 +02:00
#pragma clang fp exceptions(ignore)
#pragma clang fp reassociate(on)
#pragma clang fp contract(fast)
#include <ircd/config.h>
#include <ircd/portable.h>
2022-06-20 03:59:29 +02:00
#include <clc/clc.h>
#include <ircd/simt/simt.h>
2022-06-20 03:59:29 +02:00
#include <ircd/gpt/vector.h>
#include <ircd/gpt/opts.h>
#include <ircd/gpt/ctrl.h>
2022-06-20 03:59:29 +02:00
#include <ircd/gpt/gpu.h>
//
// head
//
// Kernel accepting the model, master, options, control and eight frame
// buffers. The body is empty: none of the arguments is read or written.
__kernel void
__attribute__((visibility("protected")))
ircd_gpt_alloc(__global const void *const restrict model,
               __global void *const restrict master,
               __constant const void *const opts,
               __global void *const restrict ctrl,
               __global void *const restrict frame0,
               __global void *const restrict frame1,
               __global void *const restrict frame2,
               __global void *const restrict frame3,
               __global void *const restrict frame4,
               __global void *const restrict frame5,
               __global void *const restrict frame6,
               __global void *const restrict frame7)
{
	// No work is performed.
}
2022-06-20 03:59:29 +02:00
// Entry kernel: the work-item with local id 0 stores the current value of
// ircd_simt_cycles() into ctrl->prof.entered. The model, state, master and
// opts arguments are not touched here.
__kernel void
__attribute__((visibility("protected")))
ircd_gpt_enter(__global const void *const restrict model,
               __global void *const restrict state,
               __global void *const restrict master,
               __constant const struct ircd_gpt_opts *const opts,
               __global struct ircd_gpt_ctrl *const restrict ctrl)
{
	// The previously declared gi, ln and cycle locals were never used (cycle
	// additionally cost a global read), so only the local id is kept.
	const ushort
	li = get_local_id(0);

	if(li == 0)
		ctrl->prof.entered = ircd_simt_cycles();
}
2022-06-20 03:59:29 +02:00
// Embedding kernel entry point. Derives the work-group's index from the
// global offset and group id, then forwards to _ircd_gpt_lm_embed() using
// that index as both the output slot and the token slot, and the local id
// as the element index.
__kernel void
ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
                  __constant const struct ircd_gpt_opts *const opts,
                  __global ircd_gpt_vectorv *const restrict accum,
                  __global const ircd_gpt_vectorv *const restrict pos,
                  __global const ircd_gpt_vectorv *const restrict vocab)
{
	const ushort local_id = get_local_id(0);
	const ushort local_size = get_local_size(0);
	const uint global_offset = get_global_offset(0);

	// The launch geometry is fixed: 192 work-items per group and a global
	// offset that is a whole number of groups.
	assume(local_size == 192);
	assume(global_offset % local_size == 0);

	// Index of the work-group, counting the groups skipped by the offset.
	const ushort group_idx = global_offset / local_size + get_group_id(0);

	_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, group_idx, group_idx, local_id);
}
// Writes one element of an embedding: the sum of the vocab row selected by
// ctrl->token[tok_idx] and the pos row tok_idx, both at elem_idx, is stored
// into accum[out_idx] at elem_idx.
static void
_ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global ircd_gpt_vectorv *const restrict accum,
                   __global const ircd_gpt_vectorv *const restrict pos,
                   __global const ircd_gpt_vectorv *const restrict vocab,
                   const ushort out_idx,
                   const ushort tok_idx,
                   const ushort elem_idx)
{
	// Token id stored at this slot of the context.
	const ushort token_id = ctrl->token[tok_idx];

	const float4 position_part = pos[tok_idx].elem[elem_idx];
	const float4 token_part = vocab[token_id].elem[elem_idx];

	accum[out_idx].elem[elem_idx] = token_part + position_part;
}
2022-06-20 03:59:29 +02:00
//
// Frontside
//
// Bias-plus-product step of the feed-forward "fcon" stage. For work-item
// column `li` of every segment, the output starts at the bias value and
// accumulates in[row] * weight[row] over all rows of the input vector.
void
ircd_gpt_ffnn_fcon_tmul(__constant const struct ircd_gpt_opts *const opts,
                        __local ircd_gpt_ffnn_aperaturev *const restrict out,
                        __local const ircd_gpt_vectorv *const restrict in,
                        __global const ircd_gpt_ffnn_aperaturev *const restrict bias,
                        __global const ircd_gpt_ffnn_aperaturev *const restrict weight,
                        const uint li)
{
	const uint lanes = 4;
	const uint segs = ircd_gpt_ffnn_segs;
	const uint height = ircd_gpt_vector_elems / lanes;

	assume(height > 0);
	assume(height % lanes == 0);

	// Seed this work-item's column of each segment with the bias.
	for(uint seg = 0; seg < segs; ++seg)
		out->proj[seg][li] = bias->proj[seg][li];

	// Walk the input as `height` groups of `lanes` components; each
	// (group, lane) pair addresses one weight row.
	for(uint grp = 0; grp < height; ++grp)
		for(uint lane = 0; lane < lanes; ++lane)
		{
			const uint row = grp * lanes + lane;

			for(uint seg = 0; seg < segs; ++seg)
				out->proj[seg][li] += in->elem[grp][lane] * weight[row].proj[seg][li];
		}
}
2022-06-20 03:59:29 +02:00
// Feed-forward "fcon" stage: runs ircd_gpt_ffnn_fcon_tmul() and then calls
// ircd_gpt_ffnn_gelu() in place on this work-item's slot of each segment.
void
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __local ircd_gpt_ffnn_aperaturev *const restrict out,
                   __local const ircd_gpt_vectorv *const restrict in,
                   __global const ircd_gpt_ffnn_aperaturev *const restrict bias,
                   __global const ircd_gpt_ffnn_aperaturev *const restrict weight,
                   const uint ln,
                   const uint li)
{
	const uint segs = ircd_gpt_ffnn_segs;

	// Bias and weighted sum for this work-item's columns.
	ircd_gpt_ffnn_fcon_tmul(opts, out, in, bias, weight, li);

	// Slot index for segment `seg` is seg * ln + li; input and output are
	// the same buffer.
	for(uint seg = 0; seg < segs; ++seg)
		ircd_gpt_ffnn_gelu(out, out, seg * ln + li);
}
2022-06-20 03:59:29 +02:00
// Projection step of the feed-forward stage. Element `li` of the output
// starts at the bias value and accumulates in[row] * weight[row] over all
// rows of the fcon-sized input.
void
ircd_gpt_ffnn_proj_tmul(__constant const struct ircd_gpt_opts *const opts,
                        __local ircd_gpt_vectorv *const restrict out,
                        __local const ircd_gpt_ffnn_aperaturev *const restrict in,
                        __global const ircd_gpt_vectorv *const restrict bias,
                        __global const ircd_gpt_vectorv *const restrict weight,
                        const uint li)
{
	const uint lanes = 4;
	const uint height = ircd_gpt_ffnn_fcon_elems / lanes;

	assume(height > 0);
	assume(height % lanes == 0);

	// Start from the bias for this element.
	out->elem[li] = bias->elem[li];

	// Each (group, lane) pair of the input addresses one weight row.
	for(uint grp = 0; grp < height; ++grp)
		for(uint lane = 0; lane < lanes; ++lane)
		{
			const uint row = grp * lanes + lane;

			out->elem[li] += in->fcon[grp][lane] * weight[row].elem[li];
		}
}
// Feed-forward block: normalizes `token` in place, expands it into `buf`
// through the fcon stage, and projects `buf` back into `token`. Work-group
// barriers separate the three stages.
void
ircd_gpt_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
              __constant const struct ircd_gpt_opts *const opts,
              __local ircd_gpt_vectorv *const restrict token,
              __local ircd_gpt_ffnn_aperaturev *const restrict buf,
              __local ircd_gpt_vectorv *const restrict tmp,
              __global const ircd_gpt_vectorv *const restrict norm_bias,
              __global const ircd_gpt_vectorv *const restrict norm_weight,
              __global const ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
              __global const ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
              __global const ircd_gpt_vectorv *const restrict proj_bias,
              __global const ircd_gpt_vectorv *const restrict proj_weight,
              const uint ln,
              const uint li)
{
	// Stage 1: layer re-normalization, reading and writing `token`.
	ircd_gpt_norm(token, token, tmp, norm_bias, norm_weight, ln, li);

	// The fcon stage reads elements written by other work-items, so the
	// normalization writes must land first.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Stage 2: fully connected expansion from `token` into `buf`.
	ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight, ln, li);

	// The projection likewise reads fcon results across work-items.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Stage 3: projection from `buf` back into `token`.
	ircd_gpt_ffnn_proj_tmul(opts, token, buf, proj_bias, proj_weight, li);
}
2022-06-20 03:59:29 +02:00
static void
ircd_gpt_attn_self_samax(__global const struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__local float self[restrict][12],
const uint ln,
const uint li,
const uint wn,
const uint wi)
{
struct ircd_math_samax samax =
{
.mu = -10000.0f,
.sum = 0.0f,
};
2022-06-20 03:59:29 +02:00
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
samax.mu = max(samax.mu, self[i][li]);
2022-06-20 03:59:29 +02:00
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
2022-06-20 03:59:29 +02:00
self[i][li] -= samax.mu;
for(uint i = 0; i < wn; ++i)
self[i][li] = native_exp(self[i][li]);
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
samax.sum += self[i][li];
2022-06-20 03:59:29 +02:00
samax.sum += FLT_EPSILON;
samax.lambda = 1.0f / samax.sum;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
self[i][li] *= samax.lambda;
}
2022-06-20 03:59:29 +02:00
// Computes one attention score: the dot product between the query of token
// `wi` and the key of token `i` for column `li`, summed over `kn` float4
// chunks and divided by 8. The result is stored in self[i][li].
static void
ircd_gpt_attn_self_keys(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __local float self[restrict][ircd_gpt_attn_rank],
                        __global const ircd_gpt_attn_qkvv *const restrict token,
                        const uint ln,
                        const uint li,
                        const uint wi,
                        const uint kn,
                        const uint i)
{
	assume(i < wi);

	// Accumulate privately; only the final value is written to `self`.
	float score = 0.0f;

	__attribute__((opencl_unroll_hint))
	for(uint k = 0; k < kn; ++k)
	{
		const float4 query = token[wi].qry.attn[li][k];
		const float4 key = token[i].key.attn[li][k];

		score += ircd_simt_hadd_f4(query * key);
	}

	// NOTE(review): 8.0f is presumably sqrt of the per-head width; confirm.
	self[i][li] = score / 8.0f;
}
// Weighted sum of values: for slot [ti][ki] of the output, accumulates
// self[i][ti] * token[i].val over tokens i in [0, wi) and stores the total.
// The ctrl, opts and li arguments are not used by this function.
static void
ircd_gpt_attn_self_vals(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __local ircd_gpt_vectorv *const restrict out,
                        __local const float self[restrict][ircd_gpt_attn_rank],
                        __global const ircd_gpt_attn_qkvv *const restrict token,
                        const uint li,
                        const uint wi,
                        const uint ki,
                        const uint ti)
{
	// Accumulate privately; the slot is written once at the end.
	float4 total = 0.0f;

	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < wi; ++i)
	{
		// The scalar weight is widened to all four lanes.
		const float4 weight = self[i][ti];
		const float4 value = token[i].val.attn[ti][ki];

		total += weight * value;
	}

	out->attn[ti][ki] = total;
}
static void
ircd_gpt_attn_self(__global struct ircd_gpt_ctrl *const ctrl,
2021-04-02 22:01:38 +02:00
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__local ircd_gpt_vectorv *const restrict out,
__local float self[restrict][ircd_gpt_attn_rank],
__global float attns[restrict][ircd_gpt_attn_rank],
__global const ircd_gpt_attn_qkvv *const restrict token,
const uint ln,
const uint li,
const uint wi)
2021-03-30 03:22:42 +02:00
{
2022-06-20 03:59:29 +02:00
//assume(opts->attn_rank == sizeof(self[0]) / sizeof(float));
assume(opts->attn_rank == ircd_gpt_attn_rank);
assume(ctrl->count < ircd_gpt_context_tokens);
assume(ctrl->tokens <= ircd_gpt_context_tokens);
assume(ctrl->tokens > wi);
assume(ctrl->tokens > 0);
2021-03-30 03:22:42 +02:00
const uint
2022-06-20 03:59:29 +02:00
wn = ctrl->tokens,
kn = ln / opts->attn_rank,
ki = li / opts->attn_rank,
2022-06-20 03:59:29 +02:00
ti = li % opts->attn_rank;
2021-03-30 03:22:42 +02:00
// Low-rank mask
if(li < opts->attn_rank)
{
2022-06-20 03:59:29 +02:00
// Left attention
uint i;
for(i = 0; i < wi; ++i)
ircd_gpt_attn_self_keys(ctrl, opts, self, token, ln, li, wi, kn, i);
2022-06-20 03:59:29 +02:00
// Future mask
__attribute__((opencl_unroll_hint))
while(i < wn)
self[i++][li] = -10000.0f;
2021-03-30 03:22:42 +02:00
// Three-piece softmax
2022-06-20 03:59:29 +02:00
ircd_gpt_attn_self_samax(ctrl, opts, self, ln, li, wn, wi);
}
2021-03-30 03:22:42 +02:00
// Propagate to full width for value dot prod.
barrier(CLK_LOCAL_MEM_FENCE);
2022-06-20 03:59:29 +02:00
ircd_gpt_attn_self_vals(ctrl, opts, out, self, token, li, wi, ki, ti);
2021-03-30 03:22:42 +02:00
2022-06-20 03:59:29 +02:00
// Save softmax results for later analysis/observation.
if(li < opts->attn_rank)
{
2022-06-20 03:59:29 +02:00
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
attns[i][li] = self[i][li];
}
2021-03-30 03:22:42 +02:00
}
2022-06-20 03:59:29 +02:00
// Projection of the self-attention result. Element `li` of the output
// starts at the bias value and accumulates in[row] * weight[row] over all
// rows of the input vector.
static void
ircd_gpt_attn_proj_tmul(__constant const struct ircd_gpt_opts *const opts,
                        __local ircd_gpt_vectorv *const restrict out,
                        __local const ircd_gpt_vectorv *const restrict in,
                        __global const ircd_gpt_vectorv *const restrict bias,
                        __global const ircd_gpt_vectorv *const restrict weight,
                        const uint li)
{
	const uint lanes = 4;
	const uint height = ircd_gpt_vector_elems / lanes;

	assume(height > 0);
	assume(height % lanes == 0);

	// Start from the bias for this element.
	out->elem[li] = bias->elem[li];

	// Each (group, lane) pair of the input addresses one weight row.
	for(uint grp = 0; grp < height; ++grp)
		for(uint lane = 0; lane < lanes; ++lane)
		{
			const uint row = grp * lanes + lane;

			// The scalar input component is widened to all four lanes.
			const float4 input = in->elem[grp][lane];
			const float4 coeff = weight[row].elem[li];

			out->elem[li] += input * coeff;
		}
}
__kernel void
2022-06-20 03:59:29 +02:00
__attribute__((visibility("protected")))
ircd_gpt_coil(__global struct ircd_gpt_ctrl *const ctrl,
2021-04-02 22:01:38 +02:00
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__private const uint layer,
__global ircd_gpt_vectorv *const restrict accum,
__global float attns[restrict][ircd_gpt_attn_rank],
__global const ircd_gpt_attn_qkvv *const restrict state,
__global const ircd_gpt_vectorv *const restrict attn_proj_bias,
__global const ircd_gpt_vectorv *const restrict attn_proj_weight,
__global const ircd_gpt_vectorv *const restrict ffnn_norm_bias,
__global const ircd_gpt_vectorv *const restrict ffnn_norm_weight,
__global const ircd_gpt_ffnn_aperaturev *const restrict ffnn_fcon_bias,
__global const ircd_gpt_ffnn_aperaturev *const restrict ffnn_fcon_weight,
__global const ircd_gpt_vectorv *const restrict ffnn_proj_bias,
__global const ircd_gpt_vectorv *const restrict ffnn_proj_weight)
2021-03-30 03:22:42 +02:00
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
2022-06-20 03:59:29 +02:00
wo = get_global_offset(0),
wi = wo / ln + get_group_id(0);
2022-06-20 03:59:29 +02:00
assume(ln == 192);
assume(wo % ln == 0);
__local union
{
2022-06-20 03:59:29 +02:00
float
attn_self[ircd_gpt_context_tokens][ircd_gpt_attn_rank];
ircd_gpt_ffnn_aperaturev
ffnn_fcon[2];
2022-06-20 03:59:29 +02:00
ircd_gpt_vectorv
vector[8];
}
buf;
2022-06-20 03:59:29 +02:00
__local ircd_gpt_vectorv
buf0, buf1,
*const restrict attn_self = &buf1,
*const restrict token = &buf0,
*const restrict tmp = &buf1;
// Self-attention backend; this computes the self-attention result now
// that keys and values are globally visible across tokens.
2021-04-02 22:01:38 +02:00
ircd_gpt_attn_self
(
ctrl,
opts,
2022-06-20 03:59:29 +02:00
attn_self,
buf.attn_self,
2022-06-20 03:59:29 +02:00
attns,
state,
ln,
li,
wi
2021-04-02 22:01:38 +02:00
);
2021-03-30 03:22:42 +02:00
barrier(CLK_LOCAL_MEM_FENCE);
// Project result of self-attention.
2022-06-20 03:59:29 +02:00
ircd_gpt_attn_proj_tmul
2021-04-02 22:01:38 +02:00
(
opts,
2022-06-20 03:59:29 +02:00
token,
attn_self,
2021-04-02 22:01:38 +02:00
attn_proj_bias,
2022-06-20 03:59:29 +02:00
attn_proj_weight,
li
2021-04-02 22:01:38 +02:00
);
2021-03-30 03:22:42 +02:00
// Frontend accumulation
{
const float4
2022-06-20 03:59:29 +02:00
attn = token->elem[li],
resid = accum[wi].elem[li],
result = resid + attn;
2022-06-20 03:59:29 +02:00
token->elem[li] = result;
accum[wi].elem[li] = result;
}
barrier(CLK_GLOBAL_MEM_FENCE);
// Backend mlp
2021-04-02 22:01:38 +02:00
ircd_gpt_ffnn
(
ctrl,
opts,
2022-06-20 03:59:29 +02:00
token,
buf.ffnn_fcon,
tmp,
2021-04-02 22:01:38 +02:00
ffnn_norm_bias,
ffnn_norm_weight,
ffnn_fcon_bias,
ffnn_fcon_weight,
ffnn_proj_bias,
2022-06-20 03:59:29 +02:00
ffnn_proj_weight,
ln,
li
2021-04-02 22:01:38 +02:00
);
// Backend accumulation
{
const float4
2022-06-20 03:59:29 +02:00
ffnn = token->elem[li],
resid = accum[wi].elem[li],
result = resid + ffnn;
2022-06-20 03:59:29 +02:00
accum[wi].elem[li] = result;
}
}
2022-06-20 03:59:29 +02:00
// Bias-plus-product step of the attention "fcon" stage. For work-item
// column `li` of every segment, the output starts at the bias value and
// accumulates in[row] * weight[row] over all rows of the input vector.
// The ln argument is not used by this function.
static void
ircd_gpt_attn_fcon_tmul(__constant const struct ircd_gpt_opts *const opts,
                        __local ircd_gpt_attn_aperaturev *const restrict out,
                        __local const ircd_gpt_vectorv *const restrict in,
                        __global const ircd_gpt_attn_aperaturev *const restrict bias,
                        __global const ircd_gpt_attn_aperaturev *const restrict weight,
                        const uint ln,
                        const uint li)
{
	const uint lanes = 4;
	const uint segs = ircd_gpt_attn_segs;
	const uint height = ircd_gpt_vector_elems / lanes;

	assume(height > 0);
	assume(height % segs == 0);
	assume(height % lanes == 0);

	// Seed this work-item's column of each segment with the bias.
	for(uint seg = 0; seg < segs; ++seg)
		out->proj[seg][li] = bias->proj[seg][li];

	// Each (group, lane) pair of the input addresses one weight row.
	for(uint grp = 0; grp < height; ++grp)
		for(uint lane = 0; lane < lanes; ++lane)
		{
			const uint row = grp * lanes + lane;

			for(uint seg = 0; seg < segs; ++seg)
			{
				// The scalar input component is widened to four lanes.
				const float4 input = in->elem[grp][lane];
				const float4 coeff = weight[row].proj[seg][li];

				out->proj[seg][li] += input * coeff;
			}
		}
}
__kernel void
ircd_gpt_attn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__private const uint layer,
__global ircd_gpt_attn_aperaturev *const restrict state,
__global const ircd_gpt_vectorv *const restrict accum,
__global const ircd_gpt_vectorv *const restrict norm_bias,
__global const ircd_gpt_vectorv *const restrict norm_weight,
__global const ircd_gpt_attn_aperaturev *const restrict fcon_bias,
__global const ircd_gpt_attn_aperaturev *const restrict fcon_weight)
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
2022-06-20 03:59:29 +02:00
wo = get_global_offset(0),
wi = wo / ln + get_group_id(0),
segs = ircd_gpt_attn_segs;
assume(ln == 192);
assume(wo % ln == 0);
2022-06-20 03:59:29 +02:00
__local ircd_gpt_attn_aperaturev
attn;
2022-06-20 03:59:29 +02:00
__local ircd_gpt_vectorv
token, *const restrict tmp = attn.vector;
2022-06-20 03:59:29 +02:00
token.elem[li] = accum[wi].elem[li];
// Layer re-normalization
2022-06-20 03:59:29 +02:00
ircd_gpt_norm(&token, &token, tmp, norm_bias, norm_weight, ln, li);
// Ln's writes are still pending; fcon requires results across threads.
barrier(CLK_LOCAL_MEM_FENCE);
// Fully connected
2022-06-20 03:59:29 +02:00
ircd_gpt_attn_fcon_tmul
(
opts,
&attn,
&token,
fcon_bias,
fcon_weight,
ln,
li
);
// Export queries, keys, and values.
2022-06-20 03:59:29 +02:00
for(uint x = 0; x < segs; ++x)
state[wi].proj[x][li] = attn.proj[x][li];
2021-04-02 22:01:38 +02:00
}
__kernel void
ircd_gpt_lm_norm(__global const struct ircd_gpt_ctrl *const ctrl,
2021-04-02 22:01:38 +02:00
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__global ircd_gpt_vectorv *const restrict accum,
__global const ircd_gpt_vectorv *const restrict norm_bias,
__global const ircd_gpt_vectorv *const restrict norm_weight)
2021-03-30 03:22:42 +02:00
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
2022-06-20 03:59:29 +02:00
wo = get_global_offset(0),
wi = wo / ln + get_group_id(0);
assume(ln == 192);
assume(wo % ln == 0);
2021-03-30 03:22:42 +02:00
2022-06-20 03:59:29 +02:00
__local ircd_gpt_vectorv
tmp, token;
2021-03-30 03:22:42 +02:00
2022-06-20 03:59:29 +02:00
token.elem[li] = accum[wi].elem[li];
2021-03-30 03:22:42 +02:00
// Final re-normalization
2022-06-20 03:59:29 +02:00
ircd_gpt_norm(&token, &token, &tmp, norm_bias, norm_weight, ln, li);
2021-03-30 03:22:42 +02:00
2022-06-20 03:59:29 +02:00
accum[wi].elem[li] = token.elem[li];
2021-03-30 03:22:42 +02:00
}
__kernel void
ircd_gpt_lm_logit(__global const struct ircd_gpt_ctrl *const ctrl,
2021-04-02 22:01:38 +02:00
__constant const struct ircd_gpt_opts *const opts,
__global float *const restrict logit,
2022-06-20 03:59:29 +02:00
__global const ircd_gpt_vectorv *const restrict accum,
__global const ircd_gpt_vectorv *const restrict pos,
__global const ircd_gpt_vectorv *const restrict vocab)
2021-03-30 03:22:42 +02:00
{
const uint
gi = get_global_id(0),
2022-06-20 03:59:29 +02:00
wi = ctrl->count - 1;
2021-03-30 03:22:42 +02:00
2022-06-20 03:59:29 +02:00
assume(opts->embed_width == 192);
assume(opts->logits <= 65536);
if(gi >= opts->logits)
{
logit[gi] = -10000.0f;
return;
}
float acc = 0.0f;
for(uint j = 0; j < opts->embed_width; ++j)
2021-03-30 03:22:42 +02:00
{
const float4
2022-06-20 03:59:29 +02:00
token = vocab[gi].elem[j],
in = accum[wi].elem[j],
wpe = pos[wi].elem[j],
res = in * token + wpe;
2021-03-30 03:22:42 +02:00
acc += ircd_simt_hadd_f4(res);
2021-03-30 03:22:42 +02:00
}
2022-06-20 03:59:29 +02:00
logit[gi] = acc;
}
__kernel void
2022-06-20 03:59:29 +02:00
__attribute__((reqd_work_group_size(256, 1, 1)))
ircd_gpt_lm_logsm(__global struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__global float logit[restrict 65536])
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
2022-06-20 03:59:29 +02:00
//wo = get_global_offset(0),
//wi = wo / ln + get_group_id(0),
wn = 50432,
tn = wn / ln,
start = tn * li,
stop = min(start + tn, opts->logits);
2022-06-20 03:59:29 +02:00
__local float
mu[256], sum[256], lambda[256];
2022-06-20 03:59:29 +02:00
__local struct ircd_math_samax
samax;
2022-06-20 03:59:29 +02:00
assume(ln == 256);
2022-06-20 03:59:29 +02:00
mu[li] = -10000.0f;
__attribute__((opencl_unroll_hint))
for(uint ti = start; ti < stop; ++ti)
mu[li] = max(mu[li], logit[ti]);
2022-06-20 03:59:29 +02:00
ircd_simt_reduce_max_flldr(mu, ln, li);
2022-06-20 03:59:29 +02:00
if(li == 0)
samax.mu = mu[li];
ircd_simt_broadcast_flldr(mu, ln, li);
2022-06-20 03:59:29 +02:00
sum[li] = 0.0f;
for(uint ti = start; ti < stop; ++ti)
{
2022-06-20 03:59:29 +02:00
const float
sub = logit[ti] - mu[li],
2022-06-20 03:59:29 +02:00
res = native_exp(sub);
2022-06-20 03:59:29 +02:00
sum[li] += res;
}
2022-06-20 03:59:29 +02:00
ircd_simt_reduce_add_flldr(sum, ln, li);
if(li == 0)
2022-06-20 03:59:29 +02:00
sum[li] += FLT_EPSILON,
samax.sum = sum[li],
samax.lambda = lambda[li] = 1.0f / sum[li];
ircd_simt_broadcast_flldr(lambda, ln, li);
for(uint ti = start; ti < stop; ++ti)
{
const float
sub = logit[ti] - mu[li],
2022-06-20 03:59:29 +02:00
res = lambda[li] * native_exp(sub);
2022-06-20 03:59:29 +02:00
logit[ti] = res;
}
2021-03-30 03:22:42 +02:00
}
2022-06-20 03:59:29 +02:00
// Records entry `i` of the top list: the token found at idx[i] and its
// value in `logsm` are stored into ctrl->top[i].
void
ircd_gpt_lm_result_top(__local struct ircd_gpt_ctrl *const ctrl,
                       __constant const struct ircd_gpt_opts *const opts,
                       __local const ushort *const restrict idx,
                       __global const float *const restrict logsm,
                       const uint i)
{
	const ushort ranked_token = idx[i];
	const float ranked_value = logsm[ranked_token];

	ctrl->top[i].token = ranked_token;
	ctrl->top[i].samax = ranked_value;
}
2022-06-20 03:59:29 +02:00
// Folds a new sample `last` into a running mean. The four partial sums are
// totalled together with the sample and divided by the incremented count;
// the sample is then added to the partial sum selected by the previous
// count modulo four.
void
ircd_gpt_lm_result_label_mean(__local struct ircd_gpt_ctrl *const ctrl,
                              __constant const struct ircd_gpt_opts *const opts,
                              __local struct ircd_math_mean *const mean,
                              const float last)
{
	// Slot choice uses the count before it is incremented.
	const uint slot = mean->div % 4;
	const uint count = mean->div + 1;

	const float total = mean->sum[0] + mean->sum[1] + mean->sum[2] + mean->sum[3] + last;
	const float average = total / count;

	mean->sum[slot] += last;
	mean->div = count;
	mean->last = last;
	mean->mean = average;
}
2022-06-20 03:59:29 +02:00
// Updates a label from the softmax output. The value of the label's token
// is stored as its samax, and two derived figures are folded into the
// label's running means: loss = -log(samax) and
// ppl = (1 - samax) * log2(opts->logits).
void
ircd_gpt_lm_result_label(__local struct ircd_gpt_ctrl *const ctrl,
                         __constant const struct ircd_gpt_opts *const opts,
                         __local struct ircd_gpt_ctrl_label *const label,
                         __global const float *const restrict logsm)
{
	const ushort label_token = label->logit.token;
	const float samax = logsm[label_token];

	const float loss = 0.0f - native_log(samax);
	const float ppl = (1.0f - samax) * native_log2(opts->logits);

	label->logit.samax = samax;

	ircd_gpt_lm_result_label_mean(ctrl, opts, &label->loss, loss);
	ircd_gpt_lm_result_label_mean(ctrl, opts, &label->ppl, ppl);
}
2022-06-20 03:59:29 +02:00
ushort
ircd_gpt_lm_result_select(__local struct ircd_gpt_ctrl *const ctrl,
__constant const struct ircd_gpt_opts *const opts,
__local const ushort *const restrict idx,
2022-06-20 03:59:29 +02:00
__global const float *const restrict logsm)
{
2021-04-02 22:01:38 +02:00
const ulong
2022-06-20 03:59:29 +02:00
ent_k = max(opts->top_k, 1U) - 1,
rnd = ircd_simt_rand_xoshiro256pl(ctrl->rand);
const float
2022-06-20 03:59:29 +02:00
ent_p = min(max(opts->top_p, 0.0f), 1.0f),
thresh = ent_p;
2022-06-20 03:59:29 +02:00
float acc = 1.0f;
ushort select = 0;
2022-06-20 03:59:29 +02:00
for(; select < ent_k; ++select)
if((acc -= logsm[idx[select]]) < thresh)
break;
2021-04-02 22:01:38 +02:00
const ushort
2022-06-20 03:59:29 +02:00
token = idx[select];
2021-04-02 22:01:38 +02:00
2022-06-20 03:59:29 +02:00
return token;
}
2022-06-20 03:59:29 +02:00
// Finalizes one step of generation: selects a token, updates the `select`
// and `target` labels, appends the selection to the context when the
// context has been fully consumed (otherwise sets accept to -2), updates
// the hit/miss counters and increments count. Returns the selected token.
static ushort
ircd_gpt_lm_result(__local struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __local const ushort *const restrict idx,
                   __global const float *const restrict logsm)
{
	const ushort token = ircd_gpt_lm_result_select(ctrl, opts, idx, logsm);

	// Label for what was selected.
	ctrl->select.logit.token = token;
	ircd_gpt_lm_result_label(ctrl, opts, &ctrl->select, logsm);

	// Label for the target: the next context token when one is pending,
	// otherwise the selection itself.
	if(ctrl->count < ctrl->tokens)
		ctrl->target.logit.token = ctrl->token[ctrl->count];
	else
		ctrl->target.logit.token = ctrl->select.logit.token;

	ircd_gpt_lm_result_label(ctrl, opts, &ctrl->target, logsm);

	const bool hit = ctrl->select.logit.token == ctrl->target.logit.token;

	// Grow the context only when every existing token has been counted.
	if(ctrl->count == ctrl->tokens)
	{
		ctrl->token[ctrl->tokens] = ctrl->select.logit.token;
		ctrl->tokens++;
	}
	else
		ctrl->accept = -2;

	ctrl->miss += !hit;
	ctrl->hit += hit;
	ctrl->count++;

	return token;
}
// For one (layer, head) pair, scans `attns` over tokens [0, ctrl->count)
// and stores into ctrl->attn[layer][head] the token index whose sampled
// entry is the smallest seen (starting from a bound of 10000).
//
// `li` enumerates opts->layers * opts->attn_rank pairs, so it is split by
// attn_rank for both the quotient (layer) and the remainder (head). The
// previous quotient divided by opts->layers, which only yields the same
// pairs when the two counts are equal. The row stride inside a token's
// block is likewise taken from opts->attn_rank rather than a literal 12.
//
// NOTE(review): the comparison keeps the minimum, not the maximum; confirm
// that is the intended meaning of "best".
static void
ircd_gpt_lm_result_attns(__local struct ircd_gpt_ctrl *const ctrl,
                         __constant const struct ircd_gpt_opts *const opts,
                         __global const float *const restrict attns,
                         const uint ln,
                         const uint li)
{
	const uint
	layer = li / opts->attn_rank,
	head = li % opts->attn_rank,
	base = layer * opts->attn_self_elems;

	uint best = 0;
	float bestv = 10000.0f;
	for(uint i = 0; i < ctrl->count; ++i)
	{
		// bx skips the blocks of all earlier tokens (block t holds t + 1
		// rows of attn_rank entries); row i of token i's block is sampled.
		const uint
		bx = (((i + 1) * i) / 2) * opts->attn_rank,
		idx = base + bx + i * opts->attn_rank + head;

		if(attns[idx] < bestv)
		{
			bestv = attns[idx];
			best = i;
		}
	}

	ctrl->attn[layer][head] = best;
}
2021-03-30 03:22:42 +02:00
__kernel void
2022-06-20 03:59:29 +02:00
__attribute__((visibility("protected")))
2022-09-24 22:49:40 +02:00
__attribute__((reqd_work_group_size(256, 1, 1)))
2022-06-20 03:59:29 +02:00
ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const restrict ctrl_,
2021-04-02 22:01:38 +02:00
__constant const struct ircd_gpt_opts *const opts,
2022-06-20 03:59:29 +02:00
__global const float logsm[restrict 65536],
__global const float *const restrict attns)
2021-03-30 03:22:42 +02:00
{
const uint
li = get_local_id(0),
ln = get_local_size(0),
2022-06-20 03:59:29 +02:00
logits_pad = ln - (opts->logits % ln),
tn = (opts->logits + logits_pad) / ln,
start = tn * li,
stop = min(start + tn, opts->logits);
2021-03-30 03:22:42 +02:00
__local ushort idx[256];
2022-06-20 03:59:29 +02:00
__local struct ircd_gpt_ctrl ctrl;
__private event_t event[1];
2021-03-30 03:22:42 +02:00
2022-06-20 03:59:29 +02:00
assume(ln == 256);
assume(start < stop);
event[0] = async_work_group_copy
(
(__local char16 *)&ctrl,
(__global const char16 *)ctrl_,
sizeof(struct ircd_gpt_ctrl) / sizeof(char16),
0
);
idx[li] = start;
for(uint j = start + 1; j < stop; ++j)
if(logsm[j] > logsm[idx[li]])
2021-03-30 03:22:42 +02:00
idx[li] = j;
2022-06-20 03:59:29 +02:00
ircd_simt_sort_idx16_flldr(idx, logsm, ln, li);
wait_group_events(1, event);
if(ctrl.count >= opts->buffer_tokens)
return;
if(li < opts->top_n)
2022-06-20 03:59:29 +02:00
ircd_gpt_lm_result_top(&ctrl, opts, idx, logsm, li);
if(li < opts->labels)
2022-06-20 03:59:29 +02:00
ircd_gpt_lm_result_label(&ctrl, opts, ctrl.label + li, logsm);
if(li < opts->layers * opts->attn_rank)
ircd_gpt_lm_result_attns(&ctrl, opts, attns, ln, li);
barrier(CLK_LOCAL_MEM_FENCE);
if(li == 0)
2022-06-20 03:59:29 +02:00
ircd_gpt_lm_result(&ctrl, opts, idx, logsm);
2022-06-20 03:59:29 +02:00
barrier(CLK_LOCAL_MEM_FENCE);
event[0] = async_work_group_copy
(
(__global char16 *)ctrl_,
(__local const char16 *)&ctrl,
sizeof(struct ircd_gpt_ctrl) / sizeof(char16),
0
);
2022-06-20 03:59:29 +02:00
wait_group_events(1, event);
}
// Exit kernel, one work-group of 256; all stores are made by work-item 0.
//
// The control block is copied to local memory and, when accept is negative,
// passed through ircd_gpt_accept(). The finish timestamp is recorded, a
// snapshot with the magic value is written to `frame`, the clock is
// advanced when not accepting (or the magic is set when accepting), and
// the control block is written back to `ctrl_`.
__kernel void
__attribute__((visibility("protected")))
__attribute__((reqd_work_group_size(256, 1, 1)))
ircd_gpt_leave(__global const void *const restrict model,
               __global void *const restrict state,
               __global void *const restrict master,
               __constant const struct ircd_gpt_opts *const opts,
               __global struct ircd_gpt_ctrl *const ctrl_,
               __global struct ircd_gpt_ctrl *const frame)
{
	const ushort li = get_local_id(0);
	const ushort ln = get_local_size(0);

	assume(ln == 256);

	__local struct ircd_gpt_ctrl _ctrl;
	__local struct ircd_gpt_ctrl *const ctrl = &_ctrl;

	// Local working copy of the control block.
	if(li == 0)
		*ctrl = *ctrl_;

	barrier(CLK_LOCAL_MEM_FENCE);

	// Evaluate acceptance only while it is still undecided (negative).
	if(li == 0 && ctrl->accept < 0)
		ircd_gpt_accept(ctrl, opts);

	barrier(CLK_LOCAL_MEM_FENCE);

	const uint batch_size = opts->batch_size;
	const uint samps = opts->training_steps + opts->validation_steps + opts->testing_steps;
	const uint steps = samps / batch_size;

	// Clock increments implied by the acceptance state.
	const bool accepting = ctrl->accept >= 0;
	const bool cycling = !accepting;
	const bool sampling = accepting;
	const bool stepping = sampling && (ctrl->clk.samp + 1) >= batch_size;
	const bool epoching = stepping && (ctrl->clk.step + 1) >= steps;

	if(li == 0)
	{
		ctrl->prof.finished = ircd_simt_cycles();

		// BARTS won't update the scalar after the copy. In this case we'll
		// be setting magic for all remaining frames :/
		#if defined(__R600__)
		ctrl->magic = 0xC7012C70UL;
		#endif

		// Snapshot for the host, tagged with the magic value.
		*frame = *ctrl;
		frame->magic = 0xC7012C70UL;

		if(!accepting)
		{
			ctrl->clk.cycle += cycling;
			ctrl->clk.samp += sampling;
			ctrl->clk.step += stepping;
			ctrl->clk.epoch += epoching;
		}
		else
			ctrl->magic = 0xC7012C70UL;

		// Publish the updated control block.
		*ctrl_ = *ctrl;
	}
}
// Decides the acceptance state and how many further frames to dispatch.
//
// The starting point is ircd_gpt_accept_check(): a non-negative result is
// an accepted sequence index, a negative one is the negated distance to
// the nearest match. Three adjustments follow, keyed on opts->limit:
//  - limit > 0 and the cycle has reached limit - 1: dispatch nothing.
//  - limit == 0 and accept was -2 on entry: dispatch the unprocessed
//    tokens, capped by the remaining room.
//  - limit != 0 and accept on entry was not an ordinary negative value:
//    dispatch at least one.
// In each adjusted case a negative accept collapses to -1. The final
// dispatch count is capped at opts->frames.
//
// Locals that were computed but never read (including a division by
// opts->batch_size) have been dropped.
void
ircd_gpt_accept(__local struct ircd_gpt_ctrl *const ctrl,
                __constant const struct ircd_gpt_opts *const opts)
{
	const bool
	unprocessed = ctrl->accept == -2,
	acceptable = ctrl->accept < 0 && !unprocessed;

	const uint
	unproc = ctrl->tokens - ctrl->count,
	limit = opts->limit;

	const int
	cycle_remain = opts->context_tokens - (ctrl->clk.cycle + 1), // cycle not yet incr
	token_remain = opts->context_tokens - ctrl->count,           // but count already incr
	remain = min(cycle_remain, token_remain);

	int
	accept = ircd_gpt_accept_check(ctrl, opts),
	dispatch = accept < 0? min(abs(accept), (uint)remain): 0;

	if(opts->limit > 0 && ctrl->clk.cycle >= limit - 1)
	{
		accept = accept >= 0? accept: -1;
		dispatch = 0;
	}

	if(opts->limit == 0 && unprocessed)
	{
		accept = accept >= 0? accept: -1;
		dispatch = min((uint)remain, unproc);
	}

	if(opts->limit != 0 && !acceptable)
	{
		accept = accept >= 0? accept: -1;
		dispatch = max(dispatch, 1);
	}

	ctrl->accept = accept;
	ctrl->dispatch = min((uint)dispatch, opts->frames);
}
// Tests the first four accept sequences. Returns the index of the first
// sequence whose remaining distance is zero; otherwise returns the negated
// smallest distance seen, starting from opts->frames.
int
ircd_gpt_accept_check(__local struct ircd_gpt_ctrl *const ctrl,
                      __constant const struct ircd_gpt_opts *const opts)
{
	int nearest = opts->frames;

	uint seq = 0;
	while(seq < 4)
	{
		const int remain = ircd_gpt_accept_match(ctrl, opts, seq);

		// A complete match ends the search.
		if(remain == 0)
			return seq;

		nearest = remain < nearest? remain: nearest;
		++seq;
	}

	return -nearest;
}
// Measures how far the tail of the counted tokens is from completing accept
// sequence `i`. For each tail length j up to min(count, len), the last j
// tokens are compared with the first j entries of the sequence; on a full
// prefix match the result becomes len - j, and the search stops when that
// reaches zero. The result is then raised to at least tokens - count and
// capped at opts->frames.
uint
ircd_gpt_accept_match(__local struct ircd_gpt_ctrl *const ctrl,
                      __constant const struct ircd_gpt_opts *const opts,
                      const uint i)
{
	const uint len = ircd_gpt_accept_len(ctrl, opts, i);
	const uint n = min(ctrl->count, len);
	const uint maxlen = opts->frames;

	// An empty sequence starts at the maximum distance.
	uint ret = len? len: maxlen;

	for(uint j = 1; j <= n; ++j)
	{
		// Count how many leading entries agree with the last j tokens.
		uint match = 0;
		while(match < j)
		{
			const uint accept = opts->accept[i][match];
			const uint token = ctrl->token[ctrl->count - j + match];

			if(token != accept)
				break;

			++match;
		}

		// Partial agreement does not update the distance.
		if(match < j)
			continue;

		ret = len - match;
		if(ret == 0)
			break;
	}

	ret = max(ret, ctrl->tokens - ctrl->count);
	ret = min(ret, maxlen);
	return ret;
}
// Length of accept sequence `i`: the number of entries before the first
// (ushort)-1 terminator, examining at most eight entries.
uint
ircd_gpt_accept_len(__local struct ircd_gpt_ctrl *const ctrl,
                    __constant const struct ircd_gpt_opts *const opts,
                    const uint i)
{
	uint len = 0;

	while(len < 8 && opts->accept[i][len] != (ushort)-1U)
		++len;

	return len;
}
//
// backside
//
/// Update one float4 parameter element together with its two running
/// moment buffers. The element is selected by this work-item's local id.
///
/// The gradient is taken from ctrl->target.loss.mean. The first moment is
/// an exponential average of the gradient (decay opts->beta[0]), the second
/// of the squared gradient (decay opts->beta[1]); on step zero both moments
/// are treated as zero rather than read from memory. The parameter is moved
/// by opts->alpha * first / (sqrt(second) + FLT_EPSILON). This resembles an
/// Adam optimizer step; no bias correction is applied here.
void
__attribute__((always_inline))
ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global float4 *const restrict param_,
                   __global float4 *const restrict exp_avg_,
                   __global float4 *const restrict exp_avg_sqr_)
{
	const uint idx = get_local_id(0);
	const uint step = ctrl->clk.step;

	const float4 param = param_[idx];
	const float4 grad = ctrl->target.loss.mean;

	// Complements of the two decay rates.
	const float4 rate[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], };

	// Moment buffers are only read after the first step.
	const float4 first_prev = step? exp_avg_[idx]: 0.0f;
	const float4 second_prev = step? exp_avg_sqr_[idx]: 0.0f;

	const float4 first_next = first_prev * opts->beta[0] + rate[0] * grad;
	const float4 second_next = second_prev * opts->beta[1] + rate[1] * grad * grad;

	const float4 denom = native_sqrt(second_next) + FLT_EPSILON;
	const float4 delta = opts->alpha * (first_next / denom);

	param_[idx] = param - delta;
	exp_avg_[idx] = first_next;
	exp_avg_sqr_[idx] = second_next;
}
2022-06-20 03:59:29 +02:00
//
// backpropagations
//
2021-04-17 20:59:30 +02:00
/// Apply ircd_gpt_prop_elem() to a bias vector and then to a weight vector.
/// Each vector is passed with its two companion moment vectors (the `_m0`
/// and `_m1` arguments).
__kernel void
__attribute__((always_inline))
ircd_gpt_norm_prop(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global ircd_gpt_vectorv *const restrict bias,
                   __global ircd_gpt_vectorv *const restrict bias_m0,
                   __global ircd_gpt_vectorv *const restrict bias_m1,
                   __global ircd_gpt_vectorv *const restrict weight,
                   __global ircd_gpt_vectorv *const restrict weight_m0,
                   __global ircd_gpt_vectorv *const restrict weight_m1)
{
	// Bias
	ircd_gpt_prop_elem(ctrl, opts, bias->elem, bias_m0->elem, bias_m1->elem);

	// Weight
	ircd_gpt_prop_elem(ctrl, opts, weight->elem, weight_m0->elem, weight_m1->elem);
}
/// Parameter update for the attention part of a layer.
///
/// Runs ircd_gpt_prop_elem() over, in order: the norm bias/weight (through
/// ircd_gpt_norm_prop()), the three segments of the fcon bias, the three
/// segments of each of opts->embed_elems fcon weight rows, the proj bias,
/// and each of opts->embed_elems proj weight rows. Every parameter buffer is
/// accompanied by its `_m0` and `_m1` moment buffers.
__kernel void
ircd_gpt_coil_prop_attn(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global ircd_gpt_vectorv *const restrict norm_bias,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m0,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m1,
                        __global ircd_gpt_vectorv *const restrict norm_weight,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m0,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m1,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_bias,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_bias_m0,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_bias_m1,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_weight,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_weight_m0,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_weight_m1,
                        __global ircd_gpt_vectorv *const restrict proj_bias,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m0,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m1,
                        __global ircd_gpt_vectorv *const restrict proj_weight,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m0,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m1)
{
	const uint segments = 3;
	const uint fcon_rows = opts->embed_elems;
	const uint proj_rows = opts->embed_elems;

	// Norm bias and weight.
	ircd_gpt_norm_prop
	(
		ctrl, opts,
		norm_bias, norm_bias_m0, norm_bias_m1,
		norm_weight, norm_weight_m0, norm_weight_m1
	);

	// Fcon bias, one call per segment.
	for(uint seg = 0; seg < segments; ++seg)
	{
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			fcon_bias->proj[seg], fcon_bias_m0->proj[seg], fcon_bias_m1->proj[seg]
		);
	}

	// Fcon weight, row-major over (row, segment).
	for(uint row = 0; row < fcon_rows; ++row)
	{
		for(uint seg = 0; seg < segments; ++seg)
		{
			ircd_gpt_prop_elem
			(
				ctrl, opts,
				fcon_weight[row].proj[seg],
				fcon_weight_m0[row].proj[seg],
				fcon_weight_m1[row].proj[seg]
			);
		}
	}

	// Proj bias.
	ircd_gpt_prop_elem
	(
		ctrl, opts,
		proj_bias->elem, proj_bias_m0->elem, proj_bias_m1->elem
	);

	// Proj weight, one call per row.
	for(uint row = 0; row < proj_rows; ++row)
	{
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			proj_weight[row].elem, proj_weight_m0[row].elem, proj_weight_m1[row].elem
		);
	}
}
/// Parameter update for the feed-forward part of a layer.
///
/// Runs ircd_gpt_prop_elem() over, in order: the norm bias/weight (through
/// ircd_gpt_norm_prop()), the four segments of the fcon bias, the four
/// segments of each of opts->embed_elems fcon weight rows, the proj bias,
/// and each of opts->ffnn_elems proj weight rows. Every parameter buffer is
/// accompanied by its `_m0` and `_m1` moment buffers.
__kernel void
ircd_gpt_coil_prop_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global ircd_gpt_vectorv *const restrict norm_bias,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m0,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m1,
                        __global ircd_gpt_vectorv *const restrict norm_weight,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m0,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m1,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m0,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m1,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m0,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m1,
                        __global ircd_gpt_vectorv *const restrict proj_bias,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m0,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m1,
                        __global ircd_gpt_vectorv *const restrict proj_weight,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m0,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m1)
{
	const uint segments = 4;
	const uint fcon_rows = opts->embed_elems;
	const uint proj_rows = opts->ffnn_elems;

	// Norm bias and weight.
	ircd_gpt_norm_prop
	(
		ctrl, opts,
		norm_bias, norm_bias_m0, norm_bias_m1,
		norm_weight, norm_weight_m0, norm_weight_m1
	);

	// Fcon bias, one call per segment.
	for(uint seg = 0; seg < segments; ++seg)
	{
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			fcon_bias->proj[seg], fcon_bias_m0->proj[seg], fcon_bias_m1->proj[seg]
		);
	}

	// Fcon weight, row-major over (row, segment).
	for(uint row = 0; row < fcon_rows; ++row)
	{
		for(uint seg = 0; seg < segments; ++seg)
		{
			ircd_gpt_prop_elem
			(
				ctrl, opts,
				fcon_weight[row].proj[seg],
				fcon_weight_m0[row].proj[seg],
				fcon_weight_m1[row].proj[seg]
			);
		}
	}

	// Proj bias.
	ircd_gpt_prop_elem
	(
		ctrl, opts,
		proj_bias->elem, proj_bias_m0->elem, proj_bias_m1->elem
	);

	// Proj weight, one call per row.
	for(uint row = 0; row < proj_rows; ++row)
	{
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			proj_weight[row].elem, proj_weight_m0[row].elem, proj_weight_m1[row].elem
		);
	}
}
/// Parameter update for the positional and token embedding tables.
///
/// The tables are split into ctrl->count equally sized slices; this
/// work-group handles slice `wi`, computed from the global offset, the local
/// size and the group id. A slice covers opts->context_tokens / count rows of
/// `pos` and opts->logits / count rows of `token`; the divisions truncate, so
/// rows past (slice size * count) fall outside slices 0..count-1. Each row is
/// updated by ircd_gpt_prop_elem() along with its `_m0` and `_m1` moments.
///
/// When ctrl->count is zero there is nothing to partition and the kernel
/// returns without touching any buffer.
__kernel void
ircd_gpt_lm_embed_prop(__global const struct ircd_gpt_ctrl *const ctrl,
                       __constant const struct ircd_gpt_opts *const opts,
                       __global ircd_gpt_vectorv *const restrict pos,
                       __global ircd_gpt_vectorv *const restrict pos_m0,
                       __global ircd_gpt_vectorv *const restrict pos_m1,
                       __global ircd_gpt_vectorv *const restrict token,
                       __global ircd_gpt_vectorv *const restrict token_m0,
                       __global ircd_gpt_vectorv *const restrict token_m1)
{
	const uint
	ln = get_local_size(0),
	wi = get_global_offset(0) / ln + get_group_id(0),
	wn = ctrl->count;

	// Integer division by zero yields an unspecified value in OpenCL C; the
	// slice bounds derived from it would index the tables arbitrarily.
	if(!wn)
		return;

	const uint
	cn = opts->context_tokens / wn,
	ci = cn * wi,
	tn = opts->logits / wn,
	ti = tn * wi;

	// Positional embedding rows of this slice.
	for(uint i = ci; i < ci + cn; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			pos[i].elem,
			pos_m0[i].elem,
			pos_m1[i].elem
		);

	// Token embedding rows of this slice.
	for(uint i = ti; i < ti + tn; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			token[i].elem,
			token_m0[i].elem,
			token_m1[i].elem
		);
}
2022-06-20 03:59:29 +02:00
/// Gaussian Error Linear Unit, tanh form, applied to the float4 at index `i`
/// of in_->fcon with the result stored at the same index of out->fcon:
///
///   0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
///
/// The constant 0.7978845608 is sqrt(2 / pi).
void
ircd_gpt_ffnn_gelu(__local ircd_gpt_ffnn_aperaturev *const out,
                   __local const ircd_gpt_ffnn_aperaturev *const in_,
                   const uint i)
{
	const float4 x = in_->fcon[i];

	// 1 + 0.044715 * x^2
	const float4 poly = (0.044715f * x) * x + 1.0f;

	// Argument of the tanh.
	const float4 inner = (poly * 0.7978845608f) * x;

	// 1 + tanh(...)
	const float4 gate = tanh(inner) + 1.0f;

	out->fcon[i] = (gate * x) * 0.5f;
}
/// Layer re-normalization followed by the per-element scale and shift.
///
/// `in` is normalized into `out` by ircd_simt_math_norm_f4lldr(), which is
/// also handed `tmp`, the local size `ln` and the local id `li`. Element
/// `li` of `out` is then rewritten in place as out * weight + bias.
void
ircd_gpt_norm(__local ircd_gpt_vectorv *const out,
              __local const ircd_gpt_vectorv *const in,
              __local ircd_gpt_vectorv *const restrict tmp,
              __global const ircd_gpt_vectorv *const restrict bias,
              __global const ircd_gpt_vectorv *const restrict weight,
              const uint ln,
              const uint li)
{
	// Normalize in -> out.
	ircd_simt_math_norm_f4lldr(out->elem, in->elem, tmp->elem, ln, li);

	// Scale and shift this work-item's element; out is both source and target.
	ircd_gpt_norm_fmad(out, out, bias, weight, li);
}
/// Scale and shift a single element: out[i] = in[i] * weight[i] + bias[i].
/// All three inputs are read before the store, so `out` may be the same
/// vector as `in`.
void
ircd_gpt_norm_fmad(__local ircd_gpt_vectorv *const out,
                   __local const ircd_gpt_vectorv *const in,
                   __global const ircd_gpt_vectorv *const restrict bias,
                   __global const ircd_gpt_vectorv *const restrict weight,
                   const uint i)
{
	const float4 value = in->elem[i];
	const float4 scale = weight->elem[i];
	const float4 shift = bias->elem[i];

	out->elem[i] = value * scale + shift;
}