2021-03-30 03:22:42 +02:00
|
|
|
// Matrix Construct
|
|
|
|
//
|
|
|
|
// Copyright (C) Matrix Construct Developers, Authors & Contributors
|
|
|
|
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
|
|
|
|
//
|
|
|
|
// Permission to use, copy, modify, and/or distribute this software for any
|
|
|
|
// purpose with or without fee is hereby granted, provided that the above
|
|
|
|
// copyright notice and this permission notice is present in all copies. The
|
|
|
|
// full license for this software is available in the LICENSE file.
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
#pragma clang fp exceptions(ignore)
|
|
|
|
#pragma clang fp reassociate(on)
|
|
|
|
#pragma clang fp contract(fast)
|
|
|
|
|
|
|
|
#include <ircd/config.h>
|
2022-10-09 02:56:58 +02:00
|
|
|
#include <ircd/portable.h>
|
2022-06-20 03:59:29 +02:00
|
|
|
#include <clc/clc.h>
|
2021-08-28 03:02:33 +02:00
|
|
|
#include <ircd/simt/simt.h>
|
2022-06-20 03:59:29 +02:00
|
|
|
#include <ircd/gpt/vector.h>
|
2021-09-18 06:08:10 +02:00
|
|
|
#include <ircd/gpt/opts.h>
|
|
|
|
#include <ircd/gpt/ctrl.h>
|
2022-06-20 03:59:29 +02:00
|
|
|
#include <ircd/gpt/gpu.h>
|
|
|
|
|
|
|
|
//
|
|
|
|
// head
|
|
|
|
//
|
|
|
|
|
|
|
|
// Allocation kernel. The body is intentionally empty: every buffer appears in
// the parameter list, but no work-item reads or writes any of them.
// NOTE(review): presumably this exists so the host can bind the model, master,
// ctrl and frame buffers in one launch; confirm against the host code.
__kernel void
__attribute__((visibility("protected")))
ircd_gpt_alloc(__global const void *const restrict model,
	__global void *const restrict master,
	__constant const void *const opts,
	__global void *const restrict ctrl,
	__global void *const restrict frame0,
	__global void *const restrict frame1,
	__global void *const restrict frame2,
	__global void *const restrict frame3,
	__global void *const restrict frame4,
	__global void *const restrict frame5,
	__global void *const restrict frame6,
	__global void *const restrict frame7)
{
	// no-op
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Entry kernel for a cycle. The first work-item of each work-group stores the
// current value of ircd_simt_cycles() into ctrl->prof.entered; every other
// work-item does nothing.
__kernel void
__attribute__((visibility("protected")))
ircd_gpt_enter(__global const void *const restrict model,
	__global void *const restrict state,
	__global void *const restrict master,
	__constant const struct ircd_gpt_opts *const opts,
	__global struct ircd_gpt_ctrl *const restrict ctrl)
{
	const ushort
	li = get_local_id(0);

	// A single timestamp per group is enough; avoid redundant global stores.
	if(li == 0)
		ctrl->prof.entered = ircd_simt_cycles();
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Store wte + wpe for one float4 lane of one token into the residual
// accumulator.
//
// out_idx:  row of `accum` that receives the sum.
// tok_idx:  context position; selects both ctrl->token[tok_idx] (the vocab
//           row) and the positional-embedding row pos[tok_idx].
// elem_idx: float4 lane within the embedding vector.
//
// Defined before the kernel that calls it: OpenCL C does not allow implicit
// function declarations, so the call site must see this definition (or a
// prototype) first.
static void
_ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__global ircd_gpt_vectorv *const restrict accum,
	__global const ircd_gpt_vectorv *const restrict pos,
	__global const ircd_gpt_vectorv *const restrict vocab,
	const ushort out_idx,
	const ushort tok_idx,
	const ushort elem_idx)
{
	const ushort
	token = ctrl->token[tok_idx];

	const float4
	wpe = pos[tok_idx].elem[elem_idx],
	wte = vocab[token].elem[elem_idx],
	res = wte + wpe;

	accum[out_idx].elem[elem_idx] = res;
}

// Embedding kernel: one work-group per context position, one work-item per
// float4 lane of the embedding (192 lanes, per the assume() below).
__kernel void
ircd_gpt_lm_embed(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__global ircd_gpt_vectorv *const restrict accum,
	__global const ircd_gpt_vectorv *const restrict pos,
	__global const ircd_gpt_vectorv *const restrict vocab)
{
	const ushort
	li = get_local_id(0),
	ln = get_local_size(0);

	const uint
	wo = get_global_offset(0);

	assume(ln == 192);
	assume(wo % ln == 0);

	// Position index: the global offset expressed in work-groups, plus this
	// group's id.
	const ushort
	wi = wo / ln + get_group_id(0);

	_ircd_gpt_lm_embed(ctrl, opts, accum, pos, vocab, wi, wi, li);
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
//
|
|
|
|
// Frontside
|
|
|
|
//
|
|
|
|
|
|
|
|
// Fully-connected expansion for the feed-forward sub-layer, computed for
// column `li` of every output segment:
//
//   out->proj[s][li] = bias->proj[s][li] + sum_r in[r] * weight[r].proj[s][li]
//
// where r enumerates every scalar of the input vector in order. `opts` is not
// referenced.
void
ircd_gpt_ffnn_fcon_tmul(__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_ffnn_aperaturev *const restrict out,
	__local const ircd_gpt_vectorv *const restrict in,
	__global const ircd_gpt_ffnn_aperaturev *const restrict bias,
	__global const ircd_gpt_ffnn_aperaturev *const restrict weight,
	const uint li)
{
	const uint
	lanes = 4,
	segs = ircd_gpt_ffnn_segs,
	height = ircd_gpt_vector_elems / lanes,
	rows = height * lanes;

	assume(height > 0);
	assume(height % lanes == 0);

	for(uint s = 0; s < segs; ++s)
		out->proj[s][li] = bias->proj[s][li];

	// Row r is scalar component (r % lanes) of float4 element (r / lanes).
	for(uint r = 0; r < rows; ++r)
		for(uint s = 0; s < segs; ++s)
			out->proj[s][li] += in->elem[r / lanes][r % lanes] * weight[r].proj[s][li];
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Feed-forward expansion: fully-connected projection of `in` into the aperture
// `out`, then ircd_gpt_ffnn_gelu applied in place to every segment. `ctrl` is
// not referenced.
void
ircd_gpt_ffnn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_ffnn_aperaturev *const restrict out,
	__local const ircd_gpt_vectorv *const restrict in,
	__global const ircd_gpt_ffnn_aperaturev *const restrict bias,
	__global const ircd_gpt_ffnn_aperaturev *const restrict weight,
	const uint ln,
	const uint li)
{
	// bias + weight * in, written to out->proj[*][li].
	ircd_gpt_ffnn_fcon_tmul(opts, out, in, bias, weight, li);

	// Activation; segment s of this lane is addressed as the flat index s*ln+li.
	for(uint s = 0; s < ircd_gpt_ffnn_segs; ++s)
		ircd_gpt_ffnn_gelu(out, out, s * ln + li);
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Contraction back to model width for the feed-forward sub-layer, computed
// for float4 lane `li`:
//
//   out->elem[li] = bias->elem[li] + sum_r in->fcon[r] * weight[r].elem[li]
//
// where r enumerates every scalar of the aperture in order. `opts` is not
// referenced.
void
ircd_gpt_ffnn_proj_tmul(__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_vectorv *const restrict out,
	__local const ircd_gpt_ffnn_aperaturev *const restrict in,
	__global const ircd_gpt_vectorv *const restrict bias,
	__global const ircd_gpt_vectorv *const restrict weight,
	const uint li)
{
	const uint
	lanes = 4,
	height = ircd_gpt_ffnn_fcon_elems / lanes,
	rows = height * lanes;

	assume(height > 0);
	assume(height % lanes == 0);

	out->elem[li] = bias->elem[li];

	// Row r is component (r % lanes) of in->fcon[r / lanes].
	for(uint r = 0; r < rows; ++r)
		out->elem[li] += in->fcon[r / lanes][r % lanes] * weight[r].elem[li];
}
|
|
|
|
|
|
|
|
// Feed-forward (MLP) sub-layer applied in place to `token`. The phases are:
//   1. normalize `token`, with `tmp` as scratch;
//   2. fully-connected expansion plus activation into `buf`;
//   3. projection of `buf` back into `token`.
// Each phase's output is read across work-items by the next phase, so the
// phases are separated by local-memory barriers.
void
ircd_gpt_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_vectorv *const restrict token,
	__local ircd_gpt_ffnn_aperaturev *const restrict buf,
	__local ircd_gpt_vectorv *const restrict tmp,
	__global const ircd_gpt_vectorv *const restrict norm_bias,
	__global const ircd_gpt_vectorv *const restrict norm_weight,
	__global const ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
	__global const ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
	__global const ircd_gpt_vectorv *const restrict proj_bias,
	__global const ircd_gpt_vectorv *const restrict proj_weight,
	const uint ln,
	const uint li)
{
	ircd_gpt_norm(token, token, tmp, norm_bias, norm_weight, ln, li);
	barrier(CLK_LOCAL_MEM_FENCE);

	ircd_gpt_ffnn_fcon(ctrl, opts, buf, token, fcon_bias, fcon_weight, ln, li);
	barrier(CLK_LOCAL_MEM_FENCE);

	ircd_gpt_ffnn_proj_tmul(opts, token, buf, proj_bias, proj_weight, li);
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// In-place softmax down column `li` of `self`, over rows [0, wn).
//
// There are three passes:
//   1. subtract the column maximum;
//   2. exponentiate;
//   3. scale by the reciprocal of the sum.
// FLT_EPSILON is added to the sum before taking the reciprocal, so an all-zero
// column does not divide by zero. The running maximum starts at -10000.0f,
// the same value this file uses for masked scores. `ctrl`, `opts`, `ln` and
// `wi` are not referenced.
//
// The row width is ircd_gpt_attn_rank, the same as the caller's buffer and
// the sibling helpers, instead of a hard-coded 12 that could silently
// disagree with it.
static void
ircd_gpt_attn_self_samax(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local float self[restrict][ircd_gpt_attn_rank],
	const uint ln,
	const uint li,
	const uint wn,
	const uint wi)
{
	struct ircd_math_samax samax =
	{
		.mu = -10000.0f,
		.sum = 0.0f,
	};

	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < wn; ++i)
		samax.mu = max(samax.mu, self[i][li]);

	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < wn; ++i)
		self[i][li] -= samax.mu;

	for(uint i = 0; i < wn; ++i)
		self[i][li] = native_exp(self[i][li]);

	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < wn; ++i)
		samax.sum += self[i][li];

	samax.sum += FLT_EPSILON;
	samax.lambda = 1.0f / samax.sum;

	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < wn; ++i)
		self[i][li] *= samax.lambda;
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Attention score for column (head) `li` between the query of position `wi`
// and the key of an earlier position `i`. The score is the float4-wise dot
// product over `kn` chunks, divided by 8, and is stored in self[i][li].
// `ctrl`, `opts` and `ln` are not referenced.
static void
ircd_gpt_attn_self_keys(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local float self[restrict][ircd_gpt_attn_rank],
	__global const ircd_gpt_attn_qkvv *const restrict token,
	const uint ln,
	const uint li,
	const uint wi,
	const uint kn,
	const uint i)
{
	assume(i < wi);

	// Accumulate privately and store once; same operation order as summing
	// directly into self[i][li].
	float dot = 0.0f;

	__attribute__((opencl_unroll_hint))
	for(uint k = 0; k < kn; ++k)
	{
		const float4
		q = token[wi].qry.attn[li][k],
		key4 = token[i].key.attn[li][k];

		dot += ircd_simt_hadd_f4(q * key4);
	}

	// NOTE(review): the fixed 1/8 scale is presumably 1/sqrt(head width).
	// With kn float4 chunks that width is 4*kn, which gives 8 only when
	// kn == 16 — confirm if ln or attn_rank ever change.
	self[i][li] = dot / 8.0f;
}
|
|
|
|
|
|
|
|
// Weighted sum of value chunks for output chunk [ti][ki]: the sum over earlier
// positions i < wi of self[i][ti] * token[i].val.attn[ti][ki]. `ctrl`, `opts`
// and `li` are not referenced.
static void
ircd_gpt_attn_self_vals(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_vectorv *const restrict out,
	__local const float self[restrict][ircd_gpt_attn_rank],
	__global const ircd_gpt_attn_qkvv *const restrict token,
	const uint li,
	const uint wi,
	const uint ki,
	const uint ti)
{
	// Private accumulator, stored once at the end; same operation order.
	float4 acc = (float4)(0.0f);

	__attribute__((opencl_unroll_hint))
	for(uint i = 0; i < wi; ++i)
	{
		const float4
		weight = (float4)(self[i][ti]),
		value = token[i].val.attn[ti][ki];

		acc += weight * value;
	}

	out->attn[ti][ki] = acc;
}
|
|
|
|
|
|
|
|
// Self-attention for the token at position `wi`.
//
// The first attn_rank work-items each own one column (head) of `self`:
//   1. score against earlier positions;
//   2. mask the rest up to ctrl->tokens;
//   3. softmax.
// After a local barrier every work-item accumulates its value chunk into
// `out`. Finally, the head lanes copy their softmax column to `attns` for
// observation.
//
// NOTE(review): scores are computed only for i < wi, position wi itself is
// masked, and the value sum also stops at i < wi, so a token never attends to
// itself. Standard causal self-attention includes the diagonal — confirm this
// is intended.
static void
ircd_gpt_attn_self(__global struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_vectorv *const restrict out,
	__local float self[restrict][ircd_gpt_attn_rank],
	__global float attns[restrict][ircd_gpt_attn_rank],
	__global const ircd_gpt_attn_qkvv *const restrict token,
	const uint ln,
	const uint li,
	const uint wi)
{
	assume(opts->attn_rank == ircd_gpt_attn_rank);
	assume(ctrl->count < ircd_gpt_context_tokens);
	assume(ctrl->tokens <= ircd_gpt_context_tokens);
	assume(ctrl->tokens > wi);
	assume(ctrl->tokens > 0);

	const uint
	wn = ctrl->tokens,
	kn = ln / opts->attn_rank,
	ki = li / opts->attn_rank,
	ti = li % opts->attn_rank;

	const bool
	head_lane = li < opts->attn_rank;

	if(head_lane)
	{
		// Scores against earlier positions.
		for(uint i = 0; i < wi; ++i)
			ircd_gpt_attn_self_keys(ctrl, opts, self, token, ln, li, wi, kn, i);

		// Mask position wi and everything after it.
		__attribute__((opencl_unroll_hint))
		for(uint i = wi; i < wn; ++i)
			self[i][li] = -10000.0f;

		ircd_gpt_attn_self_samax(ctrl, opts, self, ln, li, wn, wi);
	}

	// Scores written by the head lanes are read by all lanes below.
	barrier(CLK_LOCAL_MEM_FENCE);
	ircd_gpt_attn_self_vals(ctrl, opts, out, self, token, li, wi, ki, ti);

	// Export the softmax column for later analysis/observation.
	if(head_lane)
	{
		__attribute__((opencl_unroll_hint))
		for(uint i = 0; i < wn; ++i)
			attns[i][li] = self[i][li];
	}
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Output projection of the self-attention result, computed for float4 lane
// `li`:
//
//   out->elem[li] = bias->elem[li] + sum_r in[r] * weight[r].elem[li]
//
// where r enumerates every scalar of `in` in order. `opts` is not referenced.
static void
ircd_gpt_attn_proj_tmul(__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_vectorv *const restrict out,
	__local const ircd_gpt_vectorv *const restrict in,
	__global const ircd_gpt_vectorv *const restrict bias,
	__global const ircd_gpt_vectorv *const restrict weight,
	const uint li)
{
	const uint
	lanes = 4,
	height = ircd_gpt_vector_elems / lanes,
	rows = height * lanes;

	assume(height > 0);
	assume(height % lanes == 0);

	out->elem[li] = bias->elem[li];

	// Row r is scalar component (r % lanes) of float4 element (r / lanes).
	for(uint r = 0; r < rows; ++r)
		out->elem[li] += in->elem[r / lanes][r % lanes] * weight[r].elem[li];
}
|
|
|
|
|
|
|
|
// One transformer layer for the token at context position `wi`. There is one
// work-group per position and 192 work-items per group (see the assume()).
//
// The phases are:
//   1. self-attention over the exported q/k/v `state`;
//   2. attention projection;
//   3. residual add into `accum`;
//   4. feed-forward sub-layer;
//   5. second residual add into `accum`.
// `layer` is not referenced here.
__kernel void
__attribute__((visibility("protected")))
ircd_gpt_coil(__global struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__private const uint layer,
	__global ircd_gpt_vectorv *const restrict accum,
	__global float attns[restrict][ircd_gpt_attn_rank],
	__global const ircd_gpt_attn_qkvv *const restrict state,
	__global const ircd_gpt_vectorv *const restrict attn_proj_bias,
	__global const ircd_gpt_vectorv *const restrict attn_proj_weight,
	__global const ircd_gpt_vectorv *const restrict ffnn_norm_bias,
	__global const ircd_gpt_vectorv *const restrict ffnn_norm_weight,
	__global const ircd_gpt_ffnn_aperaturev *const restrict ffnn_fcon_bias,
	__global const ircd_gpt_ffnn_aperaturev *const restrict ffnn_fcon_weight,
	__global const ircd_gpt_vectorv *const restrict ffnn_proj_bias,
	__global const ircd_gpt_vectorv *const restrict ffnn_proj_weight)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	wo = get_global_offset(0),
	wi = wo / ln + get_group_id(0);

	assume(ln == 192);
	assume(wo % ln == 0);

	// Scratch shared between phases: attention scores first, then the ffnn
	// aperture.
	__local union
	{
		float
		attn_self[ircd_gpt_context_tokens][ircd_gpt_attn_rank];

		ircd_gpt_ffnn_aperaturev
		ffnn_fcon[2];

		ircd_gpt_vectorv
		vector[8];
	}
	buf;

	// attn_self and tmp intentionally share buf1: it holds the attention
	// output first and later serves as norm scratch. Because the same object
	// is modified through both pointers, neither may be restrict-qualified.
	__local ircd_gpt_vectorv
	buf0, buf1,
	*const restrict token = &buf0,
	*const attn_self = &buf1,
	*const tmp = &buf1;

	// Self-attention backend; this computes the self-attention result now
	// that keys and values are globally visible across tokens.
	ircd_gpt_attn_self(ctrl, opts, attn_self, buf.attn_self, attns, state, ln, li, wi);

	// The projection reads attn_self across work-items.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Project result of self-attention.
	ircd_gpt_attn_proj_tmul(opts, token, attn_self, attn_proj_bias, attn_proj_weight, li);

	// Frontend accumulation
	{
		const float4
		attn = token->elem[li],
		resid = accum[wi].elem[li],
		result = resid + attn;

		token->elem[li] = result;
		accum[wi].elem[li] = result;
	}

	// The ffnn that follows reads `token` and writes `tmp` (== attn_self,
	// which every work-item just read in the projection) and buf.ffnn_fcon.
	// Those are local-memory hazards, so fence local memory as well as the
	// accum store.
	barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);

	// Backend mlp
	ircd_gpt_ffnn
	(
		ctrl,
		opts,
		token,
		buf.ffnn_fcon,
		tmp,
		ffnn_norm_bias,
		ffnn_norm_weight,
		ffnn_fcon_bias,
		ffnn_fcon_weight,
		ffnn_proj_bias,
		ffnn_proj_weight,
		ln,
		li
	);

	// Backend accumulation
	{
		const float4
		ffnn = token->elem[li],
		resid = accum[wi].elem[li],
		result = resid + ffnn;

		accum[wi].elem[li] = result;
	}
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Fully-connected layer producing every attention segment, computed for
// column `li`:
//
//   out->proj[s][li] = bias->proj[s][li] + sum_r in[r] * weight[r].proj[s][li]
//
// where r enumerates every scalar of `in` in order. `opts` and `ln` are not
// referenced.
static void
ircd_gpt_attn_fcon_tmul(__constant const struct ircd_gpt_opts *const opts,
	__local ircd_gpt_attn_aperaturev *const restrict out,
	__local const ircd_gpt_vectorv *const restrict in,
	__global const ircd_gpt_attn_aperaturev *const restrict bias,
	__global const ircd_gpt_attn_aperaturev *const restrict weight,
	const uint ln,
	const uint li)
{
	const uint
	lanes = 4,
	segs = ircd_gpt_attn_segs,
	height = ircd_gpt_vector_elems / lanes,
	rows = height * lanes;

	assume(height > 0);
	assume(height % segs == 0);
	assume(height % lanes == 0);

	for(uint s = 0; s < segs; ++s)
		out->proj[s][li] = bias->proj[s][li];

	for(uint r = 0; r < rows; ++r)
	{
		// Scalar component (r % lanes) of float4 element (r / lanes), broadcast.
		const float4
		a = in->elem[r / lanes][r % lanes];

		for(uint s = 0; s < segs; ++s)
			out->proj[s][li] += a * weight[r].proj[s][li];
	}
}
|
|
|
|
|
2021-04-11 04:28:23 +02:00
|
|
|
// Attention front-end for the token at position `wi`. It loads the token's
// residual from `accum`, normalizes it, applies the fully-connected layer, and
// exports every segment of the result (queries, keys and values) to state[wi].
// `layer` is not referenced.
__kernel void
ircd_gpt_attn_fcon(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__private const uint layer,
	__global ircd_gpt_attn_aperaturev *const restrict state,
	__global const ircd_gpt_vectorv *const restrict accum,
	__global const ircd_gpt_vectorv *const restrict norm_bias,
	__global const ircd_gpt_vectorv *const restrict norm_weight,
	__global const ircd_gpt_attn_aperaturev *const restrict fcon_bias,
	__global const ircd_gpt_attn_aperaturev *const restrict fcon_weight)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	wo = get_global_offset(0),
	wi = wo / ln + get_group_id(0),
	segs = ircd_gpt_attn_segs;

	assume(ln == 192);
	assume(wo % ln == 0);

	__local ircd_gpt_attn_aperaturev
	attn;

	// tmp is norm scratch that lives inside `attn`. `attn` is later written
	// directly (as the fcon output), so tmp must not be restrict-qualified:
	// restrict forbids accessing a modified object through another pointer.
	__local ircd_gpt_vectorv
	token,
	*const tmp = attn.vector;

	token.elem[li] = accum[wi].elem[li];

	// Layer re-normalization
	ircd_gpt_norm(&token, &token, tmp, norm_bias, norm_weight, ln, li);

	// Ln's writes are still pending; fcon requires results across threads.
	barrier(CLK_LOCAL_MEM_FENCE);

	// Fully connected
	ircd_gpt_attn_fcon_tmul(opts, &attn, &token, fcon_bias, fcon_weight, ln, li);

	// Export queries, keys, and values.
	for(uint x = 0; x < segs; ++x)
		state[wi].proj[x][li] = attn.proj[x][li];
}
|
|
|
|
|
|
|
|
// Final normalization of the residual at position `wi`. The row is loaded into
// local memory, normalized with ircd_gpt_norm, and written back to accum[wi].
__kernel void
ircd_gpt_lm_norm(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__global ircd_gpt_vectorv *const restrict accum,
	__global const ircd_gpt_vectorv *const restrict norm_bias,
	__global const ircd_gpt_vectorv *const restrict norm_weight)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	wo = get_global_offset(0),
	wi = wo / ln + get_group_id(0);

	assume(ln == 192);
	assume(wo % ln == 0);

	__local ircd_gpt_vectorv
	token, tmp;

	token.elem[li] = accum[wi].elem[li];
	ircd_gpt_norm(&token, &token, &tmp, norm_bias, norm_weight, ln, li);
	accum[wi].elem[li] = token.elem[li];
}
|
|
|
|
|
|
|
|
// Raw logit for vocabulary entry `gi` at the last position, wi = count - 1.
// The logit is the sum over embed_width float4 lanes of accum[wi] * vocab[gi],
// plus pos[wi]. Work-items past opts->logits write -10000.0f.
//
// NOTE(review): pos[wi] is added inside the sum. That shifts every logit by
// the same amount (the component sum of pos[wi]), which cancels in a later
// softmax — confirm this is intended. Also assumes ctrl->count >= 1;
// otherwise wi wraps around.
__kernel void
ircd_gpt_lm_logit(__global const struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__global float *const restrict logit,
	__global const ircd_gpt_vectorv *const restrict accum,
	__global const ircd_gpt_vectorv *const restrict pos,
	__global const ircd_gpt_vectorv *const restrict vocab)
{
	const uint
	gi = get_global_id(0),
	wi = ctrl->count - 1;

	assume(opts->embed_width == 192);
	assume(opts->logits <= 65536);

	// Padding lanes beyond the vocabulary.
	if(gi >= opts->logits)
	{
		logit[gi] = -10000.0f;
		return;
	}

	float dot = 0.0f;
	for(uint lane = 0; lane < opts->embed_width; ++lane)
	{
		const float4
		wte = vocab[gi].elem[lane],
		x = accum[wi].elem[lane],
		wpe = pos[wi].elem[lane];

		dot += ircd_simt_hadd_f4(x * wte + wpe);
	}

	logit[gi] = dot;
}
|
|
|
|
|
|
|
|
// In-place softmax over logit[0, opts->logits), run by one 256-wide
// work-group. Each work-item owns a contiguous slice of `tn` entries.
//
// There are three passes:
//   1. the maximum, reduced and broadcast across the group;
//   2. the sum of exp(logit - max), reduced, with FLT_EPSILON added by lane 0
//      before taking the reciprocal;
//   3. logit = exp(logit - max) / sum.
//
// The slice width is derived from opts->logits; a hard-coded 50432 total
// would leave any entries past it unnormalized. For vocabularies in
// [50177, 50432] (GPT-2's 50257 included) the partition is unchanged
// (tn == 197).
__kernel void
__attribute__((reqd_work_group_size(256, 1, 1)))
ircd_gpt_lm_logsm(__global struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__global float logit[restrict 65536])
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	tn = (opts->logits + ln - 1) / ln,
	start = tn * li,
	stop = min(start + tn, opts->logits);

	__local float
	mu[256], sum[256], lambda[256];

	assume(ln == 256);
	assume(opts->logits <= 65536);

	// Pass 1: group-wide maximum.
	mu[li] = -10000.0f;
	__attribute__((opencl_unroll_hint))
	for(uint ti = start; ti < stop; ++ti)
		mu[li] = max(mu[li], logit[ti]);

	ircd_simt_reduce_max_flldr(mu, ln, li);
	ircd_simt_broadcast_flldr(mu, ln, li);

	// Pass 2: group-wide sum of shifted exponentials.
	sum[li] = 0.0f;
	for(uint ti = start; ti < stop; ++ti)
	{
		const float
		sub = logit[ti] - mu[li],
		res = native_exp(sub);

		sum[li] += res;
	}

	ircd_simt_reduce_add_flldr(sum, ln, li);

	// Lane 0 holds the reduced sum; keep its reciprocal finite.
	if(li == 0)
	{
		sum[li] += FLT_EPSILON;
		lambda[li] = 1.0f / sum[li];
	}

	ircd_simt_broadcast_flldr(lambda, ln, li);

	// Pass 3: normalize in place.
	for(uint ti = start; ti < stop; ++ti)
	{
		const float
		sub = logit[ti] - mu[li],
		res = lambda[li] * native_exp(sub);

		logit[ti] = res;
	}
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Record the i-th candidate: the token id idx[i] and its probability from
// `logsm` are written to ctrl->top[i].
void
ircd_gpt_lm_result_top(__local struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local const ushort *const restrict idx,
	__global const float *const restrict logsm,
	const uint i)
{
	const ushort tok = idx[i];

	ctrl->top[i].token = tok;
	ctrl->top[i].samax = logsm[tok];
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Fold sample `last` into the running mean `mean`. The total is kept as four
// partial sums, and the sample is added to partial (previous count % 4). The
// reported mean is the sum of all four partials plus `last`, divided by the
// new count. `ctrl` and `opts` are not referenced.
void
ircd_gpt_lm_result_label_mean(__local struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local struct ircd_math_mean *const mean,
	const float last)
{
	const uint
	prev = mean->div,
	lane = prev % 4,
	count = prev + 1;

	// Computed before the partial sum is updated, so `last` is counted once.
	const float
	total = mean->sum[0] + mean->sum[1] + mean->sum[2] + mean->sum[3] + last;

	mean->sum[lane] += last;
	mean->div = count;
	mean->last = last;
	mean->mean = total / count;
}
|
2021-09-18 08:27:23 +02:00
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
// Score one label token against the softmax output `logsm`:
//   - stores the token's probability in label->logit.samax;
//   - folds -log(p) into label->loss;
//   - folds (1 - p) * log2(opts->logits) into label->ppl.
void
ircd_gpt_lm_result_label(__local struct ircd_gpt_ctrl *const ctrl,
	__constant const struct ircd_gpt_opts *const opts,
	__local struct ircd_gpt_ctrl_label *const label,
	__global const float *const restrict logsm)
{
	const ushort
	tok = label->logit.token;

	const float
	prob = logsm[tok],
	loss = 0.0f - native_log(prob),
	ppl = (1.0f - prob) * native_log2(opts->logits);

	label->logit.samax = prob;

	ircd_gpt_lm_result_label_mean(ctrl, opts, &label->loss, loss);
	ircd_gpt_lm_result_label_mean(ctrl, opts, &label->ppl, ppl);
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
/// Choose the output token from the ranked candidate list.
///
/// Walks the candidates in idx order, subtracting each candidate's score from
/// a budget that starts at 1.0. It stops at the first candidate where the
/// budget drops below top_p (clamped to [0,1]), or at candidate top_k - 1,
/// whichever comes first.
///
/// @param idx    ranked candidate token ids. NOTE(review): the caller's array
///               holds 256 entries and top_k is not bounded against that here.
/// @param logsm  per-token scores indexed by token id.
/// @return       the selected token id.
ushort
ircd_gpt_lm_result_select(__local struct ircd_gpt_ctrl *const ctrl,
                          __constant const struct ircd_gpt_opts *const opts,
                          __local const ushort *const restrict idx,
                          __global const float *const restrict logsm)
{
	const ulong last_k = max(opts->top_k, 1U) - 1;

	// The draw is not used yet, but the call is kept. It presumably advances
	// the generator state in ctrl->rand (TODO confirm), so dropping it would
	// perturb later draws.
	const ulong rnd = ircd_simt_rand_xoshiro256pl(ctrl->rand);
	(void)rnd;

	const float threshold = min(max(opts->top_p, 0.0f), 1.0f);

	float budget = 1.0f;
	ushort pick = 0;
	while(pick < last_k)
	{
		budget -= logsm[idx[pick]];
		if(budget < threshold)
			break;

		++pick;
	}

	return idx[pick];
}
|
2021-04-17 21:01:12 +02:00
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
/// Resolve one generation step: pick a token, score it and the target, and
/// advance the token context and hit/miss counters.
///
/// The target is the next buffered token when one exists (ctrl->count <
/// ctrl->tokens); otherwise it is the selection itself. The selection is
/// appended to the context only when the buffer has been fully consumed.
/// Otherwise accept is set to -2, which ircd_gpt_accept treats as
/// "unprocessed".
///
/// @return the selected token id.
static ushort
ircd_gpt_lm_result(__local struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __local const ushort *const restrict idx,
                   __global const float *const restrict logsm)
{
	const ushort token = ircd_gpt_lm_result_select(ctrl, opts, idx, logsm);

	// Dynamic result label.
	ctrl->select.logit.token = token;
	ircd_gpt_lm_result_label(ctrl, opts, &ctrl->select, logsm);

	// Dynamic target label.
	const bool buffered = ctrl->count < ctrl->tokens;
	ctrl->target.logit.token = buffered?
		ctrl->token[ctrl->count]:
		ctrl->select.logit.token;

	ircd_gpt_lm_result_label(ctrl, opts, &ctrl->target, logsm);

	const bool hit = ctrl->select.logit.token == ctrl->target.logit.token;

	// Token context.
	if(ctrl->count != ctrl->tokens)
		ctrl->accept = -2;
	else
		ctrl->token[ctrl->tokens++] = ctrl->select.logit.token;

	ctrl->hit += hit;
	ctrl->miss += !hit;
	ctrl->count++;
	return token;
}
|
|
|
|
|
|
|
|
/// For one (layer, head) pair, find the context position whose attention
/// value is smallest, and record it in ctrl->attn[layer][head].
///
/// Lane li maps to layer = li / attn_rank and head = li % attn_rank. The
/// caller dispatches li < layers * attn_rank. Previously the layer was
/// computed as li / layers, which is only correct when layers == attn_rank.
/// The per-row head stride was also hard-coded to 12; it now uses attn_rank
/// consistently with the row offset.
///
/// NOTE(review): assumes attns holds, per layer, a packed lower-triangular
/// matrix of rows with (i + 1) entries of attn_rank heads each, and that
/// column i of row i is the entry inspected. This is inferred from the
/// indexing; TODO confirm.
///
/// @param attns  attention values for all layers.
/// @param ln     work-group size (unused).
/// @param li     lane index selecting the (layer, head) pair.
static void
ircd_gpt_lm_result_attns(__local struct ircd_gpt_ctrl *const ctrl,
                         __constant const struct ircd_gpt_opts *const opts,
                         __global const float *const restrict attns,
                         const uint ln,
                         const uint li)
{
	const uint
	heads = opts->attn_rank,
	layer = li / heads,
	head = li % heads,
	base = layer * opts->attn_self_elems;

	uint best = 0;
	float bestv = 10000.0f;
	for(uint i = 0; i < ctrl->count; ++i)
	{
		const uint
		row = (((i + 1) * i) / 2) * heads,
		idx = base + row + i * heads + head;

		if(attns[idx] < bestv)
		{
			bestv = attns[idx];
			best = i;
		}
	}

	ctrl->attn[layer][head] = best;
}
|
|
|
|
|
2021-03-30 03:22:42 +02:00
|
|
|
/// Language-model output selection kernel (one work-group of 256 lanes).
///
/// Steps:
/// - Each lane scans a contiguous slice of logsm and keeps the argmax of that
///   slice in idx[li].
/// - The 256 candidates are then sorted by ircd_simt_sort_idx16_flldr
///   (presumably descending by score, so idx[0] is best; TODO confirm).
/// - The candidates are used to fill the top list, the static labels, the
///   attention summary and the selected/target result.
/// - The control block is staged through local memory and copied back at
///   the end.
///
/// @param ctrl_  global control block (read, updated, written back).
/// @param logsm  per-token scores, indexed by token id.
/// @param attns  attention values consumed by ircd_gpt_lm_result_attns.
__kernel void
__attribute__((visibility("protected")))
__attribute__((reqd_work_group_size(256, 1, 1)))
ircd_gpt_lm_select(__global struct ircd_gpt_ctrl *const restrict ctrl_,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global const float logsm[restrict 65536],
                   __global const float *const restrict attns)
{
	const uint
	li = get_local_id(0),
	ln = get_local_size(0),
	// Logits per lane = ceil(logits / ln). The former `ln - logits % ln`
	// padding added a full extra ln when logits was already a multiple of ln.
	// That made tn one too large, so trailing lanes could start past the end
	// of the vocabulary. Result is unchanged for non-multiples.
	tn = (opts->logits + ln - 1) / ln,
	start = tn * li,
	stop = min(start + tn, opts->logits);

	__local ushort idx[256];
	__local struct ircd_gpt_ctrl ctrl;
	__private event_t event[1];

	assume(ln == 256);
	assume(start < stop);

	// Stage the control block into local memory, overlapped with the scan.
	event[0] = async_work_group_copy
	(
		(__local char16 *)&ctrl,
		(__global const char16 *)ctrl_,
		sizeof(struct ircd_gpt_ctrl) / sizeof(char16),
		0
	);

	// Per-lane argmax over this lane's slice.
	idx[li] = start;
	for(uint j = start + 1; j < stop; ++j)
		if(logsm[j] > logsm[idx[li]])
			idx[li] = j;

	ircd_simt_sort_idx16_flldr(idx, logsm, ln, li);
	wait_group_events(1, event);

	// Token buffer exhausted: leave the global control block untouched.
	if(ctrl.count >= opts->buffer_tokens)
		return;

	if(li < opts->top_n)
		ircd_gpt_lm_result_top(&ctrl, opts, idx, logsm, li);

	if(li < opts->labels)
		ircd_gpt_lm_result_label(&ctrl, opts, ctrl.label + li, logsm);

	if(li < opts->layers * opts->attn_rank)
		ircd_gpt_lm_result_attns(&ctrl, opts, attns, ln, li);

	barrier(CLK_LOCAL_MEM_FENCE);

	if(li == 0)
		ircd_gpt_lm_result(&ctrl, opts, idx, logsm);

	barrier(CLK_LOCAL_MEM_FENCE);

	// Publish the updated control block.
	event[0] = async_work_group_copy
	(
		(__global char16 *)ctrl_,
		(__local const char16 *)&ctrl,
		sizeof(struct ircd_gpt_ctrl) / sizeof(char16),
		0
	);

	wait_group_events(1, event);
}
|
|
|
|
|
|
|
|
/// End-of-cycle kernel.
///
/// Lane 0 does all of the work below, after two group barriers:
/// - loads the control block into local memory;
/// - re-evaluates acceptance when ctrl->accept is negative;
/// - stamps the profiling end time;
/// - snapshots the block into `frame`;
/// - advances the clocks;
/// - writes the block back to ctrl_.
///
/// @param ctrl_  global control block (read and written back).
/// @param frame  destination for this cycle's control snapshot.
__kernel void
__attribute__((visibility("protected")))
__attribute__((reqd_work_group_size(256, 1, 1)))
ircd_gpt_leave(__global const void *const restrict model,
               __global void *const restrict state,
               __global void *const restrict master,
               __constant const struct ircd_gpt_opts *const opts,
               __global struct ircd_gpt_ctrl *const ctrl_,
               __global struct ircd_gpt_ctrl *const frame)
{
	const ushort li = get_local_id(0);
	const ushort ln = get_local_size(0);

	assume(ln == 256);

	__local struct ircd_gpt_ctrl _ctrl;
	__local struct ircd_gpt_ctrl *const ctrl = &_ctrl;

	if(li == 0)
		*ctrl = *ctrl_;

	barrier(CLK_LOCAL_MEM_FENCE);

	if(li == 0 && ctrl->accept < 0)
		ircd_gpt_accept(ctrl, opts);

	barrier(CLK_LOCAL_MEM_FENCE);

	// No synchronization follows; the remaining bookkeeping is lane 0 only.
	if(li != 0)
		return;

	const uint
	batch_size = opts->batch_size,
	samps = opts->training_steps + opts->validation_steps + opts->testing_steps,
	steps = samps / batch_size;

	const bool
	accepting = ctrl->accept >= 0,
	cycling = !accepting,
	sampling = accepting,
	stepping = sampling && (ctrl->clk.samp + 1) >= batch_size,
	epoching = stepping && (ctrl->clk.step + 1) >= steps;

	ctrl->prof.finished = ircd_simt_cycles();

	// BARTS won't update the scalar after the copy. In this case we'll
	// be setting magic for all remaining frames :/
	#if defined(__R600__)
	ctrl->magic = 0xC7012C70UL;
	#endif

	*frame = *ctrl;
	frame->magic = 0xC7012C70UL;

	if(accepting)
		ctrl->magic = 0xC7012C70UL;
	else
	{
		// NOTE(review): in this branch sampling, stepping and epoching are
		// all false (sampling == accepting), so only clk.cycle can advance.
		// Confirm whether samp/step/epoch are meant to be advanced elsewhere.
		ctrl->clk.cycle += cycling;
		ctrl->clk.samp += sampling;
		ctrl->clk.step += stepping;
		ctrl->clk.epoch += epoching;
	}

	*ctrl_ = *ctrl;
}
|
|
|
|
|
|
|
|
/// Decide whether the current cycle is accepted, and how many frames to
/// dispatch next.
///
/// ctrl->accept is set to:
/// - a non-negative accept-sequence index when a sequence fully matched;
/// - -1 when the run is forced to stop or continue without a match;
/// - otherwise the negated minimum remaining-token count from
///   ircd_gpt_accept_check.
///
/// ctrl->dispatch is bounded by the remaining context and by opts->frames.
///
/// Changes from the previous version:
/// - `remain` is clamped at zero. When the context was overrun it went
///   negative, and `(uint)remain` wrapped to a huge value. That kept
///   dispatching past the context window, whereas remain == 0 already
///   stopped correctly.
/// - Unused locals (and an unused division by batch_size) were removed.
void
ircd_gpt_accept(__local struct ircd_gpt_ctrl *const ctrl,
                __constant const struct ircd_gpt_opts *const opts)
{
	const bool
	// -2 is set by ircd_gpt_lm_result when a buffered token was consumed.
	unprocessed = ctrl->accept == -2,
	acceptable = ctrl->accept < 0 && !unprocessed;

	const uint
	unproc = ctrl->tokens - ctrl->count,
	limit = opts->limit;

	const int
	cycle_remain = opts->context_tokens - (ctrl->clk.cycle + 1), // cycle not yet incr
	token_remain = opts->context_tokens - ctrl->count, // but count already incr
	remain = max(min(cycle_remain, token_remain), 0);

	int
	accept = ircd_gpt_accept_check(ctrl, opts),
	dispatch = accept < 0? min(abs(accept), (uint)remain): 0;

	// Positive cycle limit reached: stop dispatching.
	if(opts->limit > 0 && ctrl->clk.cycle >= limit - 1)
	{
		accept = accept >= 0? accept: -1;
		dispatch = 0;
	}

	// Zero limit: just drain the unprocessed buffered tokens.
	if(opts->limit == 0 && unprocessed)
	{
		accept = accept >= 0? accept: -1;
		dispatch = min((uint)remain, unproc);
	}

	if(opts->limit != 0 && !acceptable)
	{
		accept = accept >= 0? accept: -1;
		dispatch = max(dispatch, 1);
	}

	ctrl->accept = accept;
	ctrl->dispatch = min((uint)dispatch, opts->frames);
}
|
|
|
|
|
|
|
|
/// Check the four configured accept sequences against the token context.
///
/// @return the index of the first sequence with nothing left to match.
///         Otherwise, the negated smallest remaining count over all
///         sequences, starting from opts->frames.
int
ircd_gpt_accept_check(__local struct ircd_gpt_ctrl *const ctrl,
                      __constant const struct ircd_gpt_opts *const opts)
{
	int fewest = opts->frames;
	for(uint seq = 0; seq < 4; ++seq)
	{
		const int need = ircd_gpt_accept_match(ctrl, opts, seq);
		if(!need)
			return seq;

		fewest = min(fewest, need);
	}

	return -fewest;
}
|
|
|
|
|
|
|
|
/// Measure how many tokens remain before accept sequence `i` is complete.
///
/// For each j in 1..min(count, len), the last j context tokens are compared
/// with the first j tokens of the sequence. The longest match found sets the
/// remainder to len - j. With no match, the remainder is len (or
/// opts->frames when the sequence is empty). The result is then raised to at
/// least the number of unconsumed buffered tokens and capped at opts->frames.
uint
ircd_gpt_accept_match(__local struct ircd_gpt_ctrl *const ctrl,
                      __constant const struct ircd_gpt_opts *const opts,
                      const uint i)
{
	const uint
	len = ircd_gpt_accept_len(ctrl, opts, i),
	n = min(ctrl->count, len),
	maxlen = opts->frames;

	uint ret = len? len: maxlen;
	for(uint j = 1; j <= n; ++j)
	{
		const uint tail = ctrl->count - j;

		uint m = 0;
		while(m < j && (uint)ctrl->token[tail + m] == (uint)opts->accept[i][m])
			++m;

		if(m < j)
			continue;

		ret = len - m;
		if(!ret)
			break;
	}

	return min(max(ret, ctrl->tokens - ctrl->count), maxlen);
}
|
|
|
|
|
|
|
|
/// Length of accept sequence `i`: the count of entries before the first
/// (ushort)-1 terminator, scanning at most 8 entries.
uint
ircd_gpt_accept_len(__local struct ircd_gpt_ctrl *const ctrl,
                    __constant const struct ircd_gpt_opts *const opts,
                    const uint i)
{
	uint len = 0;
	while(len < 8 && opts->accept[i][len] != (ushort)-1U)
		++len;

	return len;
}
|
2021-04-17 20:59:30 +02:00
|
|
|
|
|
|
|
//
|
2022-06-20 03:59:29 +02:00
|
|
|
// backside
|
2021-04-17 20:59:30 +02:00
|
|
|
//
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
/// Adam-style update of one float4 lane of a parameter tensor.
///
/// The gradient is the scalar ctrl->target.loss.mean broadcast to all four
/// components. On step 0 (ctrl->clk.step == 0) the stored moments are
/// ignored and treated as zero. No bias correction is applied.
///
/// @param param_        parameters; element [local id] is updated in place.
/// @param exp_avg_      first-moment estimates; updated in place.
/// @param exp_avg_sqr_  second-moment estimates; updated in place.
void
__attribute__((always_inline))
ircd_gpt_prop_elem(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global float4 *const restrict param_,
                   __global float4 *const restrict exp_avg_,
                   __global float4 *const restrict exp_avg_sqr_)
{
	const uint lane = get_local_id(0);
	const uint step = ctrl->clk.step;

	float4 m = 0.0f;
	float4 v = 0.0f;
	if(step)
	{
		m = exp_avg_[lane];
		v = exp_avg_sqr_[lane];
	}

	const float4 g = ctrl->target.loss.mean;
	const float4 one_minus_b1 = 1.0f - opts->beta[0];
	const float4 one_minus_b2 = 1.0f - opts->beta[1];

	const float4 m_next = m * opts->beta[0] + one_minus_b1 * g;
	const float4 v_next = v * opts->beta[1] + one_minus_b2 * g * g;
	const float4 shift = opts->alpha * (m_next / (native_sqrt(v_next) + FLT_EPSILON));

	param_[lane] = param_[lane] - shift;
	exp_avg_[lane] = m_next;
	exp_avg_sqr_[lane] = v_next;
}
|
|
|
|
|
2022-06-20 03:59:29 +02:00
|
|
|
//
|
|
|
|
// backpropagations
|
|
|
|
//
|
|
|
|
|
2021-04-17 20:59:30 +02:00
|
|
|
/// Apply the optimizer step to a layer-norm's bias and weight vectors
/// (each with its two moment buffers).
__kernel void
__attribute__((always_inline))
ircd_gpt_norm_prop(__global const struct ircd_gpt_ctrl *const ctrl,
                   __constant const struct ircd_gpt_opts *const opts,
                   __global ircd_gpt_vectorv *const restrict bias,
                   __global ircd_gpt_vectorv *const restrict bias_m0,
                   __global ircd_gpt_vectorv *const restrict bias_m1,
                   __global ircd_gpt_vectorv *const restrict weight,
                   __global ircd_gpt_vectorv *const restrict weight_m0,
                   __global ircd_gpt_vectorv *const restrict weight_m1)
{
	// Bias first, then weight.
	ircd_gpt_prop_elem(ctrl, opts, bias->elem, bias_m0->elem, bias_m1->elem);
	ircd_gpt_prop_elem(ctrl, opts, weight->elem, weight_m0->elem, weight_m1->elem);
}
|
|
|
|
|
|
|
|
/// Optimizer step for one attention block.
///
/// Updates, in order:
/// - the layer-norm parameters;
/// - the fused fcon bias (3 segments);
/// - every fcon weight row (3 segments each);
/// - the projection bias;
/// - every projection weight row.
///
/// Both row counts are opts->embed_elems.
__kernel void
ircd_gpt_coil_prop_attn(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global ircd_gpt_vectorv *const restrict norm_bias,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m0,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m1,
                        __global ircd_gpt_vectorv *const restrict norm_weight,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m0,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m1,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_bias,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_bias_m0,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_bias_m1,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_weight,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_weight_m0,
                        __global ircd_gpt_attn_aperaturev *const restrict fcon_weight_m1,
                        __global ircd_gpt_vectorv *const restrict proj_bias,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m0,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m1,
                        __global ircd_gpt_vectorv *const restrict proj_weight,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m0,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m1)
{
	const uint segs = 3;
	const uint fcon_rows = opts->embed_elems;
	const uint proj_rows = opts->embed_elems;

	ircd_gpt_norm_prop(ctrl, opts,
	                   norm_bias, norm_bias_m0, norm_bias_m1,
	                   norm_weight, norm_weight_m0, norm_weight_m1);

	for(uint seg = 0; seg < segs; ++seg)
		ircd_gpt_prop_elem(ctrl, opts,
		                   fcon_bias->proj[seg],
		                   fcon_bias_m0->proj[seg],
		                   fcon_bias_m1->proj[seg]);

	for(uint row = 0; row < fcon_rows; ++row)
	{
		for(uint seg = 0; seg < segs; ++seg)
			ircd_gpt_prop_elem(ctrl, opts,
			                   fcon_weight[row].proj[seg],
			                   fcon_weight_m0[row].proj[seg],
			                   fcon_weight_m1[row].proj[seg]);
	}

	ircd_gpt_prop_elem(ctrl, opts,
	                   proj_bias->elem, proj_bias_m0->elem, proj_bias_m1->elem);

	for(uint row = 0; row < proj_rows; ++row)
		ircd_gpt_prop_elem(ctrl, opts,
		                   proj_weight[row].elem,
		                   proj_weight_m0[row].elem,
		                   proj_weight_m1[row].elem);
}
|
|
|
|
|
|
|
|
/// Optimizer step for one feed-forward block.
///
/// Updates, in order:
/// - the layer-norm parameters;
/// - the fcon bias (4 segments);
/// - every fcon weight row (opts->embed_elems rows, 4 segments each);
/// - the projection bias;
/// - every projection weight row (opts->ffnn_elems rows).
__kernel void
ircd_gpt_coil_prop_ffnn(__global const struct ircd_gpt_ctrl *const ctrl,
                        __constant const struct ircd_gpt_opts *const opts,
                        __global ircd_gpt_vectorv *const restrict norm_bias,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m0,
                        __global ircd_gpt_vectorv *const restrict norm_bias_m1,
                        __global ircd_gpt_vectorv *const restrict norm_weight,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m0,
                        __global ircd_gpt_vectorv *const restrict norm_weight_m1,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_bias,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m0,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_bias_m1,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_weight,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m0,
                        __global ircd_gpt_ffnn_aperaturev *const restrict fcon_weight_m1,
                        __global ircd_gpt_vectorv *const restrict proj_bias,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m0,
                        __global ircd_gpt_vectorv *const restrict proj_bias_m1,
                        __global ircd_gpt_vectorv *const restrict proj_weight,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m0,
                        __global ircd_gpt_vectorv *const restrict proj_weight_m1)
{
	const uint segs = 4;
	const uint fcon_rows = opts->embed_elems;
	const uint proj_rows = opts->ffnn_elems;

	ircd_gpt_norm_prop(ctrl, opts,
	                   norm_bias, norm_bias_m0, norm_bias_m1,
	                   norm_weight, norm_weight_m0, norm_weight_m1);

	for(uint seg = 0; seg < segs; ++seg)
		ircd_gpt_prop_elem(ctrl, opts,
		                   fcon_bias->proj[seg],
		                   fcon_bias_m0->proj[seg],
		                   fcon_bias_m1->proj[seg]);

	for(uint row = 0; row < fcon_rows; ++row)
	{
		for(uint seg = 0; seg < segs; ++seg)
			ircd_gpt_prop_elem(ctrl, opts,
			                   fcon_weight[row].proj[seg],
			                   fcon_weight_m0[row].proj[seg],
			                   fcon_weight_m1[row].proj[seg]);
	}

	ircd_gpt_prop_elem(ctrl, opts,
	                   proj_bias->elem, proj_bias_m0->elem, proj_bias_m1->elem);

	for(uint row = 0; row < proj_rows; ++row)
		ircd_gpt_prop_elem(ctrl, opts,
		                   proj_weight[row].elem,
		                   proj_weight_m0[row].elem,
		                   proj_weight_m1[row].elem);
}
|
|
|
|
|
|
|
|
/// Optimizer step for the positional and token embedding tables.
///
/// Both tables are partitioned across wn = ctrl->count work-groups. Group
/// wi updates rows [ci, ce) of pos and rows [ti, te) of token.
///
/// Fixes:
/// - Integer division used to silently drop the remainder rows
///   (context_tokens % wn and logits % wn); these are now given to the
///   last group.
/// - wn == 0 previously divided by zero; the kernel now does nothing.
/// - Groups beyond the partition (wi >= wn) previously indexed past the
///   tables; they now do nothing.
__kernel void
ircd_gpt_lm_embed_prop(__global const struct ircd_gpt_ctrl *const ctrl,
                       __constant const struct ircd_gpt_opts *const opts,
                       __global ircd_gpt_vectorv *const restrict pos,
                       __global ircd_gpt_vectorv *const restrict pos_m0,
                       __global ircd_gpt_vectorv *const restrict pos_m1,
                       __global ircd_gpt_vectorv *const restrict token,
                       __global ircd_gpt_vectorv *const restrict token_m0,
                       __global ircd_gpt_vectorv *const restrict token_m1)
{
	const uint
	ln = get_local_size(0),
	wi = get_global_offset(0) / ln + get_group_id(0),
	wn = ctrl->count;

	// No partition to work on; no barriers follow, so returning is safe.
	if(wn == 0 || wi >= wn)
		return;

	const bool last = wi + 1 == wn;

	const uint
	cn = opts->context_tokens / wn,
	ci = cn * wi,
	ce = last? opts->context_tokens: ci + cn,
	tn = opts->logits / wn,
	ti = tn * wi,
	te = last? opts->logits: ti + tn;

	for(uint i = ci; i < ce; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			pos[i].elem,
			pos_m0[i].elem,
			pos_m1[i].elem
		);

	for(uint i = ti; i < te; ++i)
		ircd_gpt_prop_elem
		(
			ctrl, opts,
			token[i].elem,
			token_m0[i].elem,
			token_m1[i].elem
		);
}
|
2022-06-20 03:59:29 +02:00
|
|
|
|
|
|
|
/// Gaussian Error Linear Unit (tanh approximation), applied to element i:
///   out = 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
/// where 0.7978845608 is approximately sqrt(2 / pi).
void
ircd_gpt_ffnn_gelu(__local ircd_gpt_ffnn_aperaturev *const out,
                   __local const ircd_gpt_ffnn_aperaturev *const in_,
                   const uint i)
{
	const float4 x = in_->fcon[i];
	const float4 cubic = 1.0f + 0.044715f * x * x;
	const float4 inner = 0.7978845608f * cubic * x;

	out->fcon[i] = 0.5f * (x * (1.0f + tanh(inner)));
}
|
|
|
|
|
|
|
|
/// Layer re-normalization followed by this lane's affine transform.
///
/// The normalization is done cooperatively by ircd_simt_math_norm_f4lldr
/// (defined elsewhere) using `tmp` as scratch. Lane li then applies
/// out = out * weight + bias to its own element.
void
ircd_gpt_norm(__local ircd_gpt_vectorv *const out,
              __local const ircd_gpt_vectorv *const in,
              __local ircd_gpt_vectorv *const restrict tmp,
              __global const ircd_gpt_vectorv *const restrict bias,
              __global const ircd_gpt_vectorv *const restrict weight,
              const uint ln,
              const uint li)
{
	ircd_simt_math_norm_f4lldr
	(
		out->elem,
		in->elem,
		tmp->elem,
		ln,
		li
	);

	ircd_gpt_norm_fmad
	(
		out,
		out,
		bias,
		weight,
		li
	);
}
|
|
|
|
|
|
|
|
/// Elementwise affine step for element i: out = in * weight + bias.
/// `out` and `in` may be the same vector (ircd_gpt_norm passes out twice).
void
ircd_gpt_norm_fmad(__local ircd_gpt_vectorv *const out,
                   __local const ircd_gpt_vectorv *const in,
                   __global const ircd_gpt_vectorv *const restrict bias,
                   __global const ircd_gpt_vectorv *const restrict weight,
                   const uint i)
{
	out->elem[i] = bias->elem[i] + weight->elem[i] * in->elem[i];
}
|