mirror of
https://github.com/matrix-construct/construct
synced 2025-01-12 07:54:12 +01:00
ircd::simt: Split vector reduce_add to hadd.
This commit is contained in:
parent
1f87668a28
commit
d377674748
5 changed files with 55 additions and 16 deletions
51
include/ircd/simt/hadd.h
Normal file
51
include/ircd/simt/hadd.h
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
// Matrix Construct
|
||||||
|
//
|
||||||
|
// Copyright (C) Matrix Construct Developers, Authors & Contributors
|
||||||
|
// Copyright (C) 2016-2022 Jason Volk <jason@zemos.net>
|
||||||
|
//
|
||||||
|
// Permission to use, copy, modify, and/or distribute this software for any
|
||||||
|
// purpose with or without fee is hereby granted, provided that the above
|
||||||
|
// copyright notice and this permission notice is present in all copies. The
|
||||||
|
// full license for this software is available in the LICENSE file.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#define HAVE_IRCD_SIMT_HADD_H
|
||||||
|
|
||||||
|
#if defined(__OPENCL_VERSION__) && defined(__SIZEOF_FLOAT4__)
|
||||||
|
inline float
|
||||||
|
__attribute__((always_inline))
|
||||||
|
ircd_simt_hadd_f4(const float4 in)
|
||||||
|
{
|
||||||
|
float ret = 0.0f;
|
||||||
|
for(uint i = 0; i < 4; ++i)
|
||||||
|
ret += in[i];
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__OPENCL_VERSION__) && defined(__SIZEOF_FLOAT8__)
|
||||||
|
inline float
|
||||||
|
__attribute__((always_inline))
|
||||||
|
ircd_simt_hadd_f8(const float8 in)
|
||||||
|
{
|
||||||
|
float ret = 0.0f;
|
||||||
|
for(uint i = 0; i < 8; ++i)
|
||||||
|
ret += in[i];
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__OPENCL_VERSION__) && defined(__SIZEOF_FLOAT16__)
|
||||||
|
inline float
|
||||||
|
__attribute__((always_inline))
|
||||||
|
ircd_simt_hadd_f16(const float16 in)
|
||||||
|
{
|
||||||
|
float ret = 0.0f;
|
||||||
|
for(uint i = 0; i < 16; ++i)
|
||||||
|
ret += in[i];
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -61,7 +61,7 @@ ircd_simt_math_mean_f4lldr(__local float4 *const buf,
|
||||||
if(li == 0)
|
if(li == 0)
|
||||||
{
|
{
|
||||||
const float
|
const float
|
||||||
sum = ircd_simt_reduce_add_f4(buf[li]),
|
sum = ircd_simt_hadd_f4(buf[li]),
|
||||||
div = ln * 4,
|
div = ln * 4,
|
||||||
res = sum / div;
|
res = sum / div;
|
||||||
|
|
||||||
|
|
|
@ -88,16 +88,3 @@ ircd_simt_reduce_add_ulldr(__local uint *const buf,
|
||||||
atomic_add(buf + 0, buf[li]);
|
atomic_add(buf + 0, buf[li]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __OPENCL_VERSION__
|
|
||||||
inline float
|
|
||||||
__attribute__((always_inline))
|
|
||||||
ircd_simt_reduce_add_f4(const float4 in)
|
|
||||||
{
|
|
||||||
float ret = 0.0f;
|
|
||||||
for(uint i = 0; i < 4; ++i)
|
|
||||||
ret += in[i];
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include "assert.h"
|
#include "assert.h"
|
||||||
#include "cycles.h"
|
#include "cycles.h"
|
||||||
#include "math.h"
|
#include "math.h"
|
||||||
|
#include "hadd.h"
|
||||||
#include "broadcast.h"
|
#include "broadcast.h"
|
||||||
#include "reduce_add.h"
|
#include "reduce_add.h"
|
||||||
#include "reduce_max.h"
|
#include "reduce_max.h"
|
||||||
|
|
|
@ -303,7 +303,7 @@ ircd_gpt_attn_self_keys(__global const struct ircd_gpt_ctrl *const ctrl,
|
||||||
key = token[i].key.attn[li][k],
|
key = token[i].key.attn[li][k],
|
||||||
res = qry * key;
|
res = qry * key;
|
||||||
|
|
||||||
self[i][li] += ircd_simt_reduce_add_f4(res);
|
self[i][li] += ircd_simt_hadd_f4(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
self[i][li] /= 8.0f;
|
self[i][li] /= 8.0f;
|
||||||
|
@ -679,7 +679,7 @@ ircd_gpt_lm_logit(__global const struct ircd_gpt_ctrl *const ctrl,
|
||||||
wpe = pos[wi].elem[j],
|
wpe = pos[wi].elem[j],
|
||||||
res = in * token + wpe;
|
res = in * token + wpe;
|
||||||
|
|
||||||
acc += ircd_simt_reduce_add_f4(res);
|
acc += ircd_simt_hadd_f4(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
logit[gi] = acc;
|
logit[gi] = acc;
|
||||||
|
|
Loading…
Reference in a new issue