0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-11-13 21:41:06 +01:00
construct/include/ircd/gpt/vector.h

148 lines
2.9 KiB
C
Raw Normal View History

2022-06-20 03:59:29 +02:00
// Matrix Construct
//
// Copyright (C) Matrix Construct Developers, Authors & Contributors
// Copyright (C) 2016-2022 Jason Volk <jason@zemos.net>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice is present in all copies. The
// full license for this software is available in the LICENSE file.
// Include this header at most once per translation unit and expose a feature
// macro so other code can test for its presence.
#pragma once
#define HAVE_IRCD_GPT_VECTOR_H

#ifdef __OPENCL_VERSION__
	// OpenCL build: publish the byte sizes of the float vector types (16, 32
	// and 64 bytes) unless they are already defined. The *_f32x4 definitions
	// in this header are gated on __SIZEOF_FLOAT4__.
	#ifndef __SIZEOF_FLOAT4__
		#define __SIZEOF_FLOAT4__ 16
	#endif

	#ifndef __SIZEOF_FLOAT8__
		#define __SIZEOF_FLOAT8__ 32
	#endif

	#ifndef __SIZEOF_FLOAT16__
		#define __SIZEOF_FLOAT16__ 64
	#endif
#else
	// Non-OpenCL build: erase the OpenCL __constant address-space qualifier so
	// declarations written with it still parse.
	#define __constant
#endif
// Base dimensions. The derived sizes and the array extents of the types in
// this header are computed from these values.
static __constant const uint ircd_gpt_context_tokens = 512; // 1024 was left commented out as an alternative
static __constant const uint ircd_gpt_vector_elems = 768;   // floats in one vector
static __constant const uint ircd_gpt_attn_rank = 12;       // rows of the [rank][attn_elems] view of a vector
static __constant const uint ircd_gpt_attn_segs = 3;        // vectors per attn aperature
static __constant const uint ircd_gpt_ffnn_segs = 4;        // vectors per ffnn aperature
// Sizes derived from the base dimensions.

// Floats per row when a vector is split into ircd_gpt_attn_rank rows (768 / 12 = 64).
static __constant const uint ircd_gpt_vector_attn_elems =
	ircd_gpt_vector_elems / ircd_gpt_attn_rank;

// Floats in the flat view of an attn aperature (768 * 3 = 2304).
static __constant const uint ircd_gpt_attn_fcon_elems =
	ircd_gpt_vector_elems * ircd_gpt_attn_segs;

// Floats in the flat view of an ffnn aperature (768 * 4 = 3072).
static __constant const uint ircd_gpt_ffnn_fcon_elems =
	ircd_gpt_vector_elems * ircd_gpt_ffnn_segs;
//
// embed vector
//
#ifdef __SIZEOF_FLOAT__
// One vector of ircd_gpt_vector_elems floats. Both members alias the same
// storage and cover the same number of floats (rank * attn_elems == elems).
union ircd_gpt_vector
{
	// Flat view.
	float elem[ircd_gpt_vector_elems];

	// Same floats as [ircd_gpt_attn_rank] rows of ircd_gpt_vector_attn_elems.
	// NOTE(review): presumably one row per attention head -- confirm.
	float attn[ircd_gpt_attn_rank][ircd_gpt_vector_attn_elems];
};
#endif
#ifdef __SIZEOF_FLOAT4__
// float4-packed counterpart of union ircd_gpt_vector. Array extents are the
// scalar extents divided by four; the scalar union is also overlaid.
union ircd_gpt_vector_f32x4
{
	// Flat view, four floats per element.
	float4 elem[ircd_gpt_vector_elems / 4];

	// Row view: [ircd_gpt_attn_rank] rows of ircd_gpt_vector_attn_elems / 4 float4s.
	float4 attn[ircd_gpt_attn_rank][ircd_gpt_vector_attn_elems / 4];

	// Scalar float view of the same storage.
	union ircd_gpt_vector vector;
};
#endif
//
// attn qkv
//
#ifdef __SIZEOF_FLOAT__
// Three vectors laid out consecutively in the order qry, key, val.
// NOTE(review): names suggest attention query / key / value -- confirm.
struct ircd_gpt_attn_qkv
{
	union ircd_gpt_vector qry;
	union ircd_gpt_vector key;
	union ircd_gpt_vector val;
};
#endif
#ifdef __SIZEOF_FLOAT4__
// float4-packed counterpart of struct ircd_gpt_attn_qkv: three packed vectors
// laid out consecutively in the order qry, key, val.
struct ircd_gpt_attn_qkv_f32x4
{
	union ircd_gpt_vector_f32x4 qry;
	union ircd_gpt_vector_f32x4 key;
	union ircd_gpt_vector_f32x4 val;
};
#endif
//
// attn aperature
//
#ifdef __SIZEOF_FLOAT__
// Storage for ircd_gpt_attn_segs vectors' worth of floats. Every member
// aliases the same memory and covers the same number of floats
// (ircd_gpt_attn_segs * ircd_gpt_vector_elems); only the indexing differs.
union ircd_gpt_attn_aperature
{
	// Flat view of all floats.
	float fcon[ircd_gpt_attn_fcon_elems];

	// One row of ircd_gpt_vector_elems floats per segment.
	float proj[ircd_gpt_attn_segs][ircd_gpt_vector_elems];

	// Per segment, [ircd_gpt_attn_rank] rows of ircd_gpt_vector_attn_elems floats.
	float qkv[ircd_gpt_attn_segs][ircd_gpt_attn_rank][ircd_gpt_vector_attn_elems];

	// One ircd_gpt_vector union per segment.
	union ircd_gpt_vector vector[ircd_gpt_attn_segs];
};
#endif
#ifdef __SIZEOF_FLOAT4__
// float4-packed counterpart of union ircd_gpt_attn_aperature. The innermost
// extent of each array view is the scalar extent divided by four.
union ircd_gpt_attn_aperature_f32x4
{
	// Flat view of all float4s.
	float4 fcon[ircd_gpt_attn_fcon_elems / 4];

	// One row of ircd_gpt_vector_elems / 4 float4s per segment.
	float4 proj[ircd_gpt_attn_segs][ircd_gpt_vector_elems / 4];

	// Per segment, [ircd_gpt_attn_rank] rows of ircd_gpt_vector_attn_elems / 4 float4s.
	float4 qkv[ircd_gpt_attn_segs][ircd_gpt_attn_rank][ircd_gpt_vector_attn_elems / 4];

	// One packed vector union per segment.
	union ircd_gpt_vector_f32x4 vector[ircd_gpt_attn_segs];
};
#endif
//
// ffnn aperature
//
#ifdef __SIZEOF_FLOAT__
// Storage for ircd_gpt_ffnn_segs vectors' worth of floats. Every member
// aliases the same memory and covers the same number of floats
// (ircd_gpt_ffnn_segs * ircd_gpt_vector_elems).
union ircd_gpt_ffnn_aperature
{
	// Flat view of all floats.
	float fcon[ircd_gpt_ffnn_fcon_elems];

	// One row of ircd_gpt_vector_elems floats per segment.
	float proj[ircd_gpt_ffnn_segs][ircd_gpt_vector_elems];

	// One ircd_gpt_vector union per segment.
	union ircd_gpt_vector vector[ircd_gpt_ffnn_segs];
};
#endif
#ifdef __SIZEOF_FLOAT4__
// float4-packed counterpart of union ircd_gpt_ffnn_aperature. The innermost
// extent of each array view is the scalar extent divided by four.
union ircd_gpt_ffnn_aperature_f32x4
{
	// Flat view of all float4s.
	float4 fcon[ircd_gpt_ffnn_fcon_elems / 4];

	// One row of ircd_gpt_vector_elems / 4 float4s per segment.
	float4 proj[ircd_gpt_ffnn_segs][ircd_gpt_vector_elems / 4];

	// One packed vector union per segment.
	union ircd_gpt_vector_f32x4 vector[ircd_gpt_ffnn_segs];
};
#endif