godot/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
Rémi Verschelde 38a200f1e3 oidn: Fix build for VS 2017
Backporting this upstream patch:
1e42e6db81

Fixes #39186.
2020-06-06 21:27:32 +02:00

156 lines
5.6 KiB
C++

/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef SIMPLE_CONCAT_HPP
#define SIMPLE_CONCAT_HPP
#include "memory_tracking.hpp"
#include "cpu_concat_pd.hpp"
#include "cpu_primitive.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {
template <data_type_t data_type>
struct simple_concat_t: public cpu_primitive_t {
struct pd_t: public cpu_concat_pd_t {
using cpu_concat_pd_t::cpu_concat_pd_t;
pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) {
int ndims = rhs.dst_md_.ndims;
utils::array_copy(perm_, rhs.perm_, ndims);
utils::array_copy(iperm_, rhs.iperm_, ndims);
utils::array_copy(blocks_, rhs.blocks_, ndims);
}
DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);
status_t init() {
const memory_desc_wrapper dst_d(dst_md());
bool ok = true
&& cpu_concat_pd_t::init() == status::success
&& dst_d.ndims() <= 6;
if (!ok) return status::unimplemented;
for (size_t i = 0; i < src_mds_.size(); ++i) {
const memory_desc_wrapper i_d(&src_mds_[i]);
const memory_desc_wrapper o_d(&src_image_mds_[i]);
const int ignore_strides = 0;
ok = ok
&& utils::everyone_is(data_type, i_d.data_type(),
o_d.data_type())
&& utils::everyone_is(format_kind::blocked,
i_d.format_kind(), o_d.format_kind())
&& types::blocking_desc_is_equal(i_d.blocking_desc(),
o_d.blocking_desc(), ignore_strides)
&& types::blocking_desc_is_equal(i_d.blocking_desc(),
dst_d.blocking_desc(), ignore_strides)
&& !i_d.is_additional_buffer();
if (!ok) return status::unimplemented;
}
dst_d.compute_blocks(blocks_);
format_perm();
// start dim is the first dimension after which the concatenation
// would happen contiguously
const int start_dim = perm_[concat_dim()];
// check that contiguous part is indeed contiguous (i.e. dense)
if (nelems_to_concat(dst_d) !=
dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
* dst_d.blocking_desc().strides[concat_dim()])
return status::unimplemented;
// check that all inputs have the same strides for the
// contiguous part [concat_dim .. ndims] for the *major* dims.
// the block part is already checked above
for (size_t i = 0; i < src_mds_.size(); ++i) {
const memory_desc_wrapper i_d(&src_mds_[i]);
for (int d = start_dim; d < dst_d.ndims(); ++d) {
if (dst_d.blocking_desc().strides[iperm_[d]]
!= i_d.blocking_desc().strides[iperm_[d]])
return status::unimplemented;
}
}
init_scratchpad();
return status::success;
}
int perm_[MKLDNN_MAX_NDIMS] {};
int iperm_[MKLDNN_MAX_NDIMS] {};
dims_t blocks_ {};
dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
const int ndims = data_d.ndims();
dim_t nelems = 1;
for (int i = perm_[concat_dim()]; i < ndims; i++)
nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]];
for (int i = 0; i < ndims; i++)
nelems *= blocks_[i];
return nelems;
}
private:
void format_perm() {
const memory_desc_wrapper dst_d(dst_md());
const int ndims = dst_d.ndims();
strides_t strides;
utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);
for (int i = 0; i < ndims; i++) iperm_[i] = i;
utils::simultaneous_sort(strides, iperm_, ndims,
[](stride_t a, stride_t b) { return b - a; });
for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
}
void init_scratchpad() {
using namespace memory_tracking::names;
auto scratchpad = scratchpad_registry().registrar();
scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs());
scratchpad.book(key_concat_istrides,
sizeof(strides_t) * n_inputs());
}
};
simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {}
virtual status_t execute(const exec_ctx_t &ctx) const override;
typedef typename prec_traits<data_type>::type data_t;
private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
}
}
}
#endif