0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-07-06 18:48:36 +02:00

ircd::math: Allow increased mean template precision.

This commit is contained in:
Jason Volk 2021-03-15 20:37:31 -07:00
parent 05a3e505cc
commit 70bb5c257a

View file

@ -13,35 +13,39 @@
namespace ircd::math namespace ircd::math
{ {
template<class T> template<class T,
typename std::enable_if<!simd::is<T>(), T>::type class R = T>
mean(const vector_view<const T> &); typename std::enable_if<!simd::is<T>(), R>::type
mean(const vector_view<const T>);
template<class T> template<class T,
typename std::enable_if<simd::is<T>(), simd::lane_type<T>>::type class R = T>
mean(const vector_view<const T> &); typename std::enable_if<simd::is<T>(), simd::lane_type<R>>::type
mean(const vector_view<const T>);
} }
template<class T> template<class T,
inline typename std::enable_if<ircd::simd::is<T>(), ircd::simd::lane_type<T>>::type class R>
ircd::math::mean(const vector_view<const T> &a) inline typename std::enable_if<ircd::simd::is<T>(), ircd::simd::lane_type<R>>::type
ircd::math::mean(const vector_view<const T> a)
{ {
using value_type = simd::lane_type<T>; R acc {0};
simd::for_each(a.data(), u64x2{0, a.size()}, [&acc]
const auto &sum (const auto block, const auto mask)
{ {
simd::accumulate(a.data(), u64x2{0, a.size()}, T{0}, [] const R dp
(auto &ret, const auto block, const auto mask) (
{ simd::lane_cast<R>(block)
ret += block; );
})
};
value_type num {0}; acc += dp;
for(size_t i{0}; i < simd::lanes<T>(); ++i) });
num += sum[i];
const auto &den auto num(acc[0]);
for(uint i(1); i < simd::lanes<T>(); ++i)
num += acc[i];
const auto den
{ {
a.size() * simd::lanes<T>() a.size() * simd::lanes<T>()
}; };
@ -50,12 +54,12 @@ ircd::math::mean(const vector_view<const T> &a)
return num; return num;
} }
template<class T> template<class T,
inline typename std::enable_if<!ircd::simd::is<T>(), T>::type class R>
ircd::math::mean(const vector_view<const T> &a) inline typename std::enable_if<!ircd::simd::is<T>(), R>::type
ircd::math::mean(const vector_view<const T> a)
{ {
T ret{0}; R ret{0};
size_t i{0}; size_t i{0};
while(i < a.size()) while(i < a.size())
ret += a[i++]; ret += a[i++];